Class: TensorFlow::Keras::Metrics::Mean
- Inherits: Object
  - Object
  - TensorFlow::Keras::Metrics::Mean
- Defined in:
- lib/tensorflow/keras/metrics/mean.rb
Direct Known Subclasses
Instance Method Summary
- #call(*args) ⇒ Object
- #initialize(name: nil, dtype: :float) ⇒ Mean (constructor): A new instance of Mean.
- #reset_states ⇒ Object
- #result ⇒ Object
- #update_state(values) ⇒ Object
Constructor Details
#initialize(name: nil, dtype: :float) ⇒ Mean
Returns a new instance of Mean.
# File 'lib/tensorflow/keras/metrics/mean.rb', line 5
def initialize(name: nil, dtype: :float)
  @dtype = dtype
  @total = Utils.add_weight(name: "total", initializer: "zeros", dtype: @dtype)
  @count = Utils.add_weight(name: "count", initializer: "zeros", dtype: @dtype)
end
Instance Method Details
#call(*args) ⇒ Object
# File 'lib/tensorflow/keras/metrics/mean.rb', line 11
def call(*args)
  update_state(*args)
end
#reset_states ⇒ Object
# File 'lib/tensorflow/keras/metrics/mean.rb', line 26
def reset_states
end
#result ⇒ Object
# File 'lib/tensorflow/keras/metrics/mean.rb', line 22
def result
  RawOps.div_no_nan(x: @total, y: TensorFlow.cast(@count, :float))
end
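The #result listing divides the accumulated total by the count using a "no NaN" division, which yields zero instead of NaN when the denominator is zero. The following is a minimal plain-Ruby sketch of that behaviour, not the gem's implementation; `safe_div` is a hypothetical helper name introduced only for illustration, and it works on Ruby numbers rather than tensors.

```ruby
# Hypothetical stand-in for the "divide, but return 0 when the
# denominator is 0" behaviour that #result relies on.
# This is NOT part of the tensorflow gem; it operates on plain numbers.
def safe_div(total, count)
  # An empty metric has count == 0; report 0.0 instead of NaN.
  return 0.0 if count.zero?

  total.to_f / count
end

puts safe_div(6.0, 3.0) # mean of three values summing to 6
puts safe_div(0.0, 0.0) # no values recorded yet
```

Under these assumptions, a metric that has not yet seen any values reports 0.0 rather than NaN.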
#update_state(values) ⇒ Object
# File 'lib/tensorflow/keras/metrics/mean.rb', line 15
def update_state(values)
  input = TensorFlow.convert_to_tensor(values)
  input = TensorFlow.cast(input, @dtype)
  @total.assign_add(Math.reduce_sum(input))
  @count.assign_add(TensorFlow.cast(RawOps.size(input: input), @dtype))
end
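The listings above suggest a running mean: each #update_state call adds the sum of all elements to a total and the number of elements to a count, and #result divides one by the other. The sketch below mirrors that accumulation in plain Ruby so the arithmetic can be followed without TensorFlow. `PlainMean` is a hypothetical class written for this illustration; it uses Ruby floats and arrays instead of tensors and weights, and it is not the gem's class.

```ruby
# Hypothetical plain-Ruby analogue of the running-mean metric.
# It only illustrates the total/count bookkeeping; it is not the gem's class.
class PlainMean
  def initialize
    @total = 0.0 # running sum of every element seen so far
    @count = 0.0 # running number of elements seen so far
  end

  # Accepts a scalar or a (possibly nested) array, like a tensor of any rank.
  def update_state(values)
    flat = Array(values).flatten.map(&:to_f)
    @total += flat.sum
    @count += flat.size
    self
  end

  # Mean over everything seen so far; 0.0 when nothing has been recorded.
  def result
    @count.zero? ? 0.0 : @total / @count
  end
end

m = PlainMean.new
m.update_state([1, 2, 3])        # total = 6,  count = 3
m.update_state([[4, 5], [6, 7]]) # total = 28, count = 7
puts m.result                    # 4.0
```

Note that the mean is taken over individual elements across all calls, not over per-call averages: the two batches above average to 2.0 and 5.5 separately, but the metric reports 28 / 7 = 4.0.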