Class: TensorFlow::Keras::Metrics::SparseCategoricalAccuracy

Inherits:
TensorFlow::Keras::Metrics::Mean < Object
Defined in:
lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb

Instance Method Summary

#update_state(y_true, y_pred) ⇒ Object

Methods inherited from Mean

#call, #initialize, #reset_states, #result
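SparseCategoricalAccuracy overrides only #update_state, so #result and #reset_states behave as they do in Mean. Assuming Mean follows the usual Keras design of keeping a running total and count across batches, its streaming behavior looks roughly like the plain-Ruby sketch below. `RunningMean` is a hypothetical stand-in, not the gem's class, and returning 0.0 before any update is a choice of this sketch.

```ruby
# Hypothetical pure-Ruby stand-in for a streaming mean metric.
# It mirrors the usual Keras Mean design (running total and count);
# it is NOT TensorFlow::Keras::Metrics::Mean.
class RunningMean
  def initialize
    reset_states
  end

  # Add one batch of values; booleans count as 1.0 (true) or 0.0 (false).
  def update_state(values)
    values.each do |v|
      @total += to_number(v)
      @count += 1
    end
  end

  # Mean of every value seen since the last reset (0.0 before any update).
  def result
    @count.zero? ? 0.0 : @total / @count
  end

  # Forget all accumulated values.
  def reset_states
    @total = 0.0
    @count = 0
  end

  private

  def to_number(v)
    case v
    when true then 1.0
    when false then 0.0
    else v.to_f
    end
  end
end

m = RunningMean.new
m.update_state([true, false, true, true])
m.result # => 0.75
m.update_state([false, false, false, false])
m.result # => 0.375 (3 matches out of 8 values so far)
m.reset_states
m.result # => 0.0
```

The point of the accumulation is that the metric reports accuracy over all batches seen so far, not just the latest one, until #reset_states is called.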

Constructor Details

This class inherits a constructor from TensorFlow::Keras::Metrics::Mean

Instance Method Details

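#update_state(y_true, y_pred) ⇒ Object

Converts both arguments to tensors, takes the index of the highest score along the last axis of y_pred as the predicted class, casts that index to the dtype of y_true when the two differ, and passes the element-wise equality of labels and predicted classes to Mean#update_state. The running mean is therefore the fraction of samples whose predicted class matches its sparse label.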
# File 'lib/tensorflow/keras/metrics/sparse_categorical_accuracy.rb', line 5

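def update_state(y_true, y_pred)
  # Accept Ruby arrays or tensors for both arguments.
  y_true = TensorFlow.convert_to_tensor(y_true)
  y_pred = TensorFlow.convert_to_tensor(y_pred)

  # Reduce per-class scores to the predicted class index along the last axis.
  y_pred = RawOps.arg_max(input: y_pred, dimension: -1)

  # arg_max yields an integer index tensor; match the labels' dtype before comparing.
  if y_pred.dtype != y_true.dtype
    y_pred = TensorFlow.cast(y_pred, y_true.dtype)
  end

  # Boolean tensor: true where the predicted class equals the label; Mean averages it.
  super(Math.equal(y_true, y_pred))
end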
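To make the steps of #update_state concrete, here is a plain-Ruby sketch of the same per-batch computation on nested Arrays instead of tensors. The helpers `sparse_categorical_matches` and `sparse_categorical_accuracy` are hypothetical and not part of the gem; the sketch assumes each row of `y_pred` holds per-class scores and each entry of `y_true` is a class index. Tie-breaking in the argmax step is not guaranteed to match TensorFlow's.

```ruby
# Hypothetical plain-Ruby sketch; NOT part of the tensorflow gem.

# Per-sample match flags, mirroring the argmax -> cast -> equal steps.
def sparse_categorical_matches(y_true, y_pred)
  y_true.zip(y_pred).map do |label, scores|
    # Argmax over the last axis: index of the highest score in the row.
    predicted = scores.each_index.max_by { |i| scores[i] }
    # Mirror the dtype cast: convert the index to the label's type.
    predicted = predicted.to_f if label.is_a?(Float)
    predicted == label
  end
end

# Fraction of matches in one batch (what Mean would average).
def sparse_categorical_accuracy(y_true, y_pred)
  matches = sparse_categorical_matches(y_true, y_pred)
  matches.count(true).fdiv(matches.size)
end

labels = [0, 2, 1]
scores = [
  [0.8, 0.1, 0.1], # argmax 0, label 0: correct
  [0.2, 0.5, 0.3], # argmax 1, label 2: wrong
  [0.1, 0.7, 0.2]  # argmax 1, label 1: correct
]
sparse_categorical_matches(labels, scores)  # => [true, false, true]
sparse_categorical_accuracy(labels, scores) # => 0.6666666666666666
```

Labels are class indices rather than one-hot vectors, which is what distinguishes the "sparse" variant: only the position of the top score is compared, never the score values themselves.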