Class: TensorFlow::Keras::Metrics::Mean
- Inherits: Object
  - Ancestry: Object > TensorFlow::Keras::Metrics::Mean
- Defined in:
- lib/tensorflow/keras/metrics/mean.rb
Instance Method Summary
Constructor Details
#initialize(name: nil, dtype: :float) ⇒ Mean
Returns a new instance of Mean.
5
6
7
8
9
|
# File 'lib/tensorflow/keras/metrics/mean.rb', line 5
# Builds the metric's two accumulators: a running total and a running count.
#
# @param name [String, nil] currently unused by this constructor
# @param dtype [Symbol] dtype used for both accumulators and for the values
#   cast in #update_state (default :float)
def initialize(name: nil, dtype: :float)
  @dtype = dtype
  # Both weights start at zero; "total" is created before "count".
  @total, @count = %w[total count].map do |weight_name|
    Utils.add_weight(name: weight_name, initializer: "zeros", dtype: @dtype)
  end
end
|
Instance Method Details
#call(*args) ⇒ Object
11
12
13
|
# File 'lib/tensorflow/keras/metrics/mean.rb', line 11
# Callable shorthand for #update_state: every positional argument is
# forwarded unchanged.
#
# @return whatever #update_state returns (this does not return #result)
def call(*inputs)
  update_state(*inputs)
end
|
#reset_states ⇒ Object
26
27
|
# File 'lib/tensorflow/keras/metrics/mean.rb', line 26
# Clears the metric so that subsequent #update_state calls start a fresh mean.
#
# The original body was empty, so totals and counts leaked across resets.
# Both accumulators are re-created as zero-initialized weights, using the
# same Utils.add_weight calls as the constructor.
# NOTE(review): if the weight objects support an in-place assign, resetting
# them to zero in place would avoid allocating new weights; confirm the
# Variable API before switching.
#
# @return [nil]
def reset_states
  @total = Utils.add_weight(name: "total", initializer: "zeros", dtype: @dtype)
  @count = Utils.add_weight(name: "count", initializer: "zeros", dtype: @dtype)
  nil
end
|
#result ⇒ Object
22
23
24
|
# File 'lib/tensorflow/keras/metrics/mean.rb', line 22
# Returns the current mean, total / count, with 0 when count is zero
# (RawOps.div_no_nan).
#
# The count is cast to the metric's own dtype rather than a hard-coded
# :float. div_no_nan needs both operands to share one type, and @total is
# created with @dtype. The old :float cast therefore broke any metric
# constructed with a different float dtype (e.g. :double). For the default
# :float, behavior is unchanged.
def result
  RawOps.div_no_nan(x: @total, y: TensorFlow.cast(@count, @dtype))
end
|
#update_state(values) ⇒ Object
15
16
17
18
19
20
|
# File 'lib/tensorflow/keras/metrics/mean.rb', line 15
# Folds a batch of observations into the running total and count.
#
# +values+ is converted to a tensor and cast to the metric dtype. The sum of
# its elements is added to @total, and its element count (RawOps.size, cast
# to the metric dtype) is added to @count.
#
# @param values [Object] anything TensorFlow.convert_to_tensor accepts
# @return the value returned by @count.assign_add
def update_state(values)
  batch = TensorFlow.cast(TensorFlow.convert_to_tensor(values), @dtype)
  @total.assign_add(Math.reduce_sum(batch))
  element_count = RawOps.size(input: batch)
  @count.assign_add(TensorFlow.cast(element_count, @dtype))
end
|