Class: Tensorflow::Keras::Metrics::Mean

Inherits:
Object
  • Object
Defined in:
lib/tensorflow/keras/metrics/mean.rb

Direct Known Subclasses

SparseCategoricalAccuracy

Instance Method Summary

Constructor Details

#initialize(name: nil, dtype: :float) ⇒ Mean

Returns a new instance of Mean.



5
6
7
8
9
# File 'lib/tensorflow/keras/metrics/mean.rb', line 5

# Builds a Mean metric backed by two zero-initialized weights.
#
# @param name [Object, nil] accepted for API compatibility; not used by this method
# @param dtype [Symbol] dtype of the accumulator weights (default :float)
def initialize(name: nil, dtype: :float)
  @dtype = dtype
  # "total" accumulates the sum of seen values, "count" the number of them.
  @total, @count = %w[total count].map do |weight_name|
    Utils.add_weight(name: weight_name, initializer: "zeros", dtype: @dtype)
  end
end

Instance Method Details

#call(*args) ⇒ Object



11
12
13
# File 'lib/tensorflow/keras/metrics/mean.rb', line 11

# Convenience entry point: feeds the given arguments straight to #update_state.
#
# @param inputs [Array] positional arguments forwarded unchanged
# @return [Object] whatever #update_state returns
def call(*inputs)
  update_state(*inputs)
end

#reset_states ⇒ Object



25
26
# File 'lib/tensorflow/keras/metrics/mean.rb', line 25

# Clears the accumulated state so the metric starts over from zero.
#
# The total and count are rebuilt as fresh zero-initialized weights of the
# metric's dtype, using the same Utils.add_weight calls as #initialize.
# NOTE(review): this replaces the variable objects instead of zeroing them in
# place; if the weight objects support an in-place assign, that may be preferable.
#
# @return [nil]
def reset_states
  @total = Utils.add_weight(name: "total", initializer: "zeros", dtype: @dtype)
  @count = Utils.add_weight(name: "count", initializer: "zeros", dtype: @dtype)
  nil
end

#result ⇒ Object



21
22
23
# File 'lib/tensorflow/keras/metrics/mean.rb', line 21

# Returns the current mean: running total divided by running count.
#
# RawOps.div_no_nan yields 0 rather than NaN when the count is zero, so a
# metric that has seen no values reports 0.
#
# @return [Object] the quotient tensor produced by RawOps.div_no_nan
def result
  # Cast the count to the metric's own dtype (not a hard-coded :float) so both
  # division operands share the dtype @total was created with.
  RawOps.div_no_nan(@total, Tensorflow.cast(@count, @dtype))
end

#update_state(values) ⇒ Object



15
16
17
18
19
# File 'lib/tensorflow/keras/metrics/mean.rb', line 15

# Accumulates a batch of values into the running mean.
#
# Each element of +values+ carries equal weight. The element sum is added to
# the running total, and the element count is added to the running count.
#
# @param values [Object] input convertible by Tensorflow.cast (assumed tensor,
#   array, or scalar — TODO confirm accepted types)
# @return [Object] the result of the final assign_add on the count weight
def update_state(values)
  # Cast the caller's values (previously the unassigned local `input` was cast,
  # which is nil) to the metric dtype. The dtype is passed positionally, as in
  # every other Tensorflow.cast call in this class.
  input = Tensorflow.cast(values, @dtype)
  @total.assign_add(Math.reduce_sum(input))
  @count.assign_add(Tensorflow.cast(RawOps.size(input), @dtype))
end