Class: DNN::Models::ModelEvaluator

Inherits:
Object
Defined in:
lib/dnn/core/models.rb


Constructor Details

#initialize(model) ⇒ ModelEvaluator

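Returns a new instance of ModelEvaluator for the given model. The evaluator starts in the :none state, so #evaluating? returns false until an evaluation is started.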



# File 'lib/dnn/core/models.rb', line 883

def initialize(model)
  @model = model
  @state = :none
end

Instance Method Details

#evaluating? ⇒ Boolean

Checks whether an evaluation is currently in progress.

Returns:

  • (Boolean)

    Returns true while an evaluation is in progress (that is, while the internal state is not :none).



# File 'lib/dnn/core/models.rb', line 926

def evaluating?
  @state != :none
end

#start_evaluate(x, y, batch_size: 100, need_accuracy: true) ⇒ Array

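Starts evaluating the model on the test data x and y, from which the accuracy and loss are computed. This method only prepares the evaluation: it wraps the data in an unshuffled Iterator and delegates to #start_evaluate_by_iterator. The evaluation then proceeds one step per call to #update.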

Parameters:

  • x (Numo::SFloat)

    Input test data.

  • y (Numo::SFloat)

    Expected output (target) test data.

  • batch_size (Integer) (defaults to: 100)

    Number of samples evaluated per test step. Values larger than the number of test samples are capped at that number.

  • need_accuracy (Boolean) (defaults to: true)

    Set to true to compute the accuracy.

Returns:

  • (Array)

    Returns the accuracy and mean loss on the test data in the form [accuracy, mean_loss], or [nil, mean_loss] if accuracy is not needed.



# File 'lib/dnn/core/models.rb', line 895

def start_evaluate(x, y, batch_size: 100, need_accuracy: true)
  Utils.check_input_data_type("x", x, Xumo::SFloat)
  Utils.check_input_data_type("y", y, Xumo::SFloat)
  start_evaluate_by_iterator(Iterator.new(x, y, random: false), batch_size: batch_size, need_accuracy: need_accuracy)
end
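start_evaluate wraps x and y in an Iterator created with random: false, so the test data is visited in its original order. The plain-Ruby sketch below illustrates that idea with a hypothetical OrderedBatcher class over ordinary arrays. It is not the DNN::Iterator API (only num_datas appears in this section), and it does not use Numo.

```ruby
# Hypothetical sequential batcher (not part of ruby-dnn).
# Illustrates iterating over test data without shuffling.
class OrderedBatcher
  attr_reader :num_datas

  def initialize(x, y)
    raise ArgumentError, "x and y must have the same length" unless x.size == y.size
    @x = x
    @y = y
    @num_datas = x.size
  end

  # Yields [x_batch, y_batch] pairs in the original order.
  # The last batch may be smaller than batch_size.
  def each_batch(batch_size)
    (0...@num_datas).step(batch_size) do |start|
      yield @x[start, batch_size], @y[start, batch_size]
    end
  end
end

xs = (1..5).to_a
ys = xs.map { |v| v * 10 }
batches = []
OrderedBatcher.new(xs, ys).each_batch(2) { |bx, by| batches << [bx, by] }
p batches.map(&:first) # prints [[1, 2], [3, 4], [5]]
```

Keeping the order fixed makes evaluation runs reproducible, and the final, possibly smaller, batch always contains the last samples.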

#start_evaluate_by_iterator(test_iterator, batch_size: 100, need_accuracy: true) ⇒ Array

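Starts evaluating the model using the given iterator. This method only initializes the evaluation state; the evaluation then proceeds one step per call to #update.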

Parameters:

  • test_iterator (DNN::Iterator)

    Iterator used for testing.

  • batch_size (Integer) (defaults to: 100)

    Number of samples evaluated per test step. Values larger than the number of test samples are capped at that number.

  • need_accuracy (Boolean) (defaults to: true)

    Set to true to compute the accuracy.

Returns:

  • (Array)

    Returns the accuracy and mean loss on the test data in the form [accuracy, mean_loss], or [nil, mean_loss] if accuracy is not needed.



# File 'lib/dnn/core/models.rb', line 907

def start_evaluate_by_iterator(test_iterator, batch_size: 100, need_accuracy: true)
  @test_iterator = test_iterator
  @num_test_datas = test_iterator.num_datas
  @batch_size = batch_size >= @num_test_datas ? @num_test_datas : batch_size
  @need_accuracy = need_accuracy
  if @loss_func.is_a?(Array)
    @total_correct = Array.new(@loss_func.length, 0)
    @sum_loss = Array.new(@loss_func.length, 0)
  else
    @total_correct = 0
    @sum_loss = 0
  end
  @step = 1
  @max_steps = (@num_test_datas.to_f / @batch_size).ceil
  @state = :start_step
end
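The number of evaluation steps follows from two lines of the source above: batch_size is capped at the number of test samples, and max_steps is the ceiling of num_test_datas / batch_size. The helper below, eval_plan (a hypothetical name), reproduces just that arithmetic in isolation.

```ruby
# Reproduces the batch-size clamp and step count from
# start_evaluate_by_iterator. eval_plan is a hypothetical helper.
def eval_plan(num_test_datas, batch_size)
  # A batch can never be larger than the whole test set.
  batch_size = batch_size >= num_test_datas ? num_test_datas : batch_size
  # Round up so a final, partial batch still gets its own step.
  max_steps = (num_test_datas.to_f / batch_size).ceil
  [batch_size, max_steps]
end

p eval_plan(250, 100) # prints [100, 3]
p eval_plan(250, 500) # prints [250, 1]
p eval_plan(200, 100) # prints [100, 2]
```

When the loss function is an Array (one loss per model output), the source keeps one correct-answer count and one loss sum per output, each initialized with Array.new(@loss_func.length, 0); otherwise both are single integers starting at 0.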

#update ⇒ Object

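Advances the evaluation by one step. Depending on the current state (:start_step, :test_step, :end_step, or :end_evaluate), it calls the method of the same name; it does nothing when no evaluation is in progress.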



# File 'lib/dnn/core/models.rb', line 931

def update
  case @state
  when :start_step
    start_step
  when :test_step
    test_step
  when :end_step
    end_step
  when :end_evaluate
    end_evaluate
  end
end
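#update is a small state machine: each call runs the handler for the current @state, so a caller drives the whole evaluation by calling #update until #evaluating? returns false. The handlers (start_step, test_step, end_step, end_evaluate) are not shown in this section, so the transitions in the sketch below are an assumption: a cycle of start, test, and end for each batch, then end_evaluate once all max_steps batches are done, which resets the state to :none. SketchEvaluator is a hypothetical class, not part of ruby-dnn.

```ruby
# Hypothetical state machine mirroring the dispatch in ModelEvaluator#update.
# The transitions between states are assumed, not taken from ruby-dnn.
class SketchEvaluator
  attr_reader :log

  def initialize(max_steps)
    @max_steps = max_steps
    @step = 1
    @state = :start_step
    @log = []
  end

  # Same check as ModelEvaluator#evaluating?
  def evaluating?
    @state != :none
  end

  # Same dispatch shape as ModelEvaluator#update.
  def update
    case @state
    when :start_step then start_step
    when :test_step then test_step
    when :end_step then end_step
    when :end_evaluate then end_evaluate
    end
  end

  private

  def start_step
    @log << "start #{@step}"
    @state = :test_step
  end

  def test_step
    @log << "test #{@step}" # a real evaluator would process one batch here
    @state = :end_step
  end

  def end_step
    @log << "end #{@step}"
    @step += 1
    @state = @step > @max_steps ? :end_evaluate : :start_step
  end

  def end_evaluate
    @log << "done"
    @state = :none
  end
end

ev = SketchEvaluator.new(2)
ev.update while ev.evaluating?
p ev.log
# prints ["start 1", "test 1", "end 1", "start 2", "test 2", "end 2", "done"]
```

Splitting evaluation into one small step per call lets a host loop (for example, a GUI or game loop) interleave other work between batches instead of blocking until the whole test set has been processed.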