Class: DNN::Models::ModelTrainer

Inherits:
Object
Defined in:
lib/dnn/core/models.rb

Instance Method Summary

Constructor Details

#initialize(model) ⇒ ModelTrainer

Returns a new instance of ModelTrainer.



# File 'lib/dnn/core/models.rb', line 632

def initialize(model)
  @model = model
  @state = :none
  @initial_epoch = 1
  @step = 1
  @max_steps = 1
  @train_iterator = nil
  @max_epochs = 1
  @batch_size = 1
  @epoch = 1
  @test = nil
  @verbose = false
  @need_accuracy = false
  @io = nil
  @num_train_datas = 0
end
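
A minimal usage sketch (assuming a model that has already been set up with an optimizer and a loss function; the model class and setup arguments below are illustrative, not part of this API):

require "dnn"

model = MyModel.new   # hypothetical DNN::Models::Model subclass
model.setup(DNN::Optimizers::Adam.new, DNN::Losses::MeanSquaredError.new)

trainer = DNN::Models::ModelTrainer.new(model)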

Instance Method Details

#start_train(x, y, epochs, batch_size: 1, initial_epoch: 1, test: nil, verbose: true, need_accuracy: true, io: $stdout) ⇒ Object

Start training. Set up the model before using this method.

Parameters:

  • x (Numo::SFloat)

    Input training data.

  • y (Numo::SFloat)

    Output training data.

  • epochs (Integer)

    Number of training epochs.

  • batch_size (Integer) (defaults to: 1)

    Batch size used for one training.

  • initial_epoch (Integer) (defaults to: 1)

    Initial epoch.

  • test (Array | NilClass) (defaults to: nil)

    To test the model after every epoch, specify [x_test, y_test]. To skip testing, specify nil.

  • verbose (Boolean) (defaults to: true)

    Set to true to display the training log; if false, the log is suppressed.

  • need_accuracy (Boolean) (defaults to: true)

    Set true to compute the accuracy.

  • io (IO) (defaults to: $stdout)

    Specifies the IO object to use for logging.



# File 'lib/dnn/core/models.rb', line 661

def start_train(x, y, epochs,
                batch_size: 1,
                initial_epoch: 1,
                test: nil,
                verbose: true,
                need_accuracy: true,
                io: $stdout)
  Utils.check_input_data_type("x", x, Xumo::SFloat)
  Utils.check_input_data_type("y", y, Xumo::SFloat)
  train_iterator = Iterator.new(x, y)
  start_train_by_iterator(train_iterator, epochs,
                          batch_size: batch_size,
                          initial_epoch: initial_epoch,
                          test: test,
                          verbose: verbose,
                          need_accuracy: need_accuracy,
                          io: io)
end
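
For example (a sketch only; x and y here are random Numo::SFloat arrays standing in for real training data), after creating a trainer as above:

x = Numo::SFloat.new(100, 10).rand   # 100 samples, 10 features
y = Numo::SFloat.new(100, 1).rand    # 100 target values

trainer.start_train(x, y, 5, batch_size: 16, verbose: true)
# Training then proceeds by calling #update repeatedly (see #update below).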

#start_train_by_iterator(train_iterator, epochs, batch_size: 1, initial_epoch: 1, test: nil, verbose: true, need_accuracy: true, io: $stdout) ⇒ Object

Start training with an iterator. Set up the model before using this method.

Parameters:

  • train_iterator (DNN::Iterator)

    Iterator used for training.

  • epochs (Integer)

    Number of training epochs.

  • batch_size (Integer) (defaults to: 1)

    Batch size used for one training.

  • initial_epoch (Integer) (defaults to: 1)

    Initial epoch.

  • test (Array | NilClass) (defaults to: nil)

    To test the model after every epoch, specify [x_test, y_test]. To skip testing, specify nil.

  • verbose (Boolean) (defaults to: true)

    Set to true to display the training log; if false, the log is suppressed.

  • need_accuracy (Boolean) (defaults to: true)

    Set true to compute the accuracy.

  • io (IO) (defaults to: $stdout)

    Specifies the IO object to use for logging.

Raises:

  • (DNNError)

    If the model's optimizer or loss function has not been set up.

# File 'lib/dnn/core/models.rb', line 691

def start_train_by_iterator(train_iterator, epochs,
                            batch_size: 1,
                            initial_epoch: 1,
                            test: nil,
                            verbose: true,
                            need_accuracy: true,
                            io: $stdout)
  raise DNNError, "The model is not optimizer setup complete." unless @model.optimizer
  raise DNNError, "The model is not loss_func setup complete." unless @model.loss_func
  @model.check_early_stop_requested # Clear early stop request.
  @train_iterator = train_iterator
  @max_epochs = epochs
  @batch_size = batch_size
  @epoch = initial_epoch
  @test = test
  @verbose = verbose
  @need_accuracy = need_accuracy
  @io = io
  @state = :start_epoch
  @max_steps = train_iterator.max_steps(batch_size)
  @num_train_datas = train_iterator.num_usable_datas(batch_size)
  @line_first_pos = 0
  @model.call_callbacks(:before_train)
end
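
For example (a sketch reusing the x and y arrays from the #start_train example), the iterator can be constructed directly when you need control over it:

iter = DNN::Iterator.new(x, y)
trainer.start_train_by_iterator(iter, 5, batch_size: 16, need_accuracy: false)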

#training? ⇒ Boolean

Check if the trainer is currently training.

Returns:

  • (Boolean)

    Returns true if currently training.



# File 'lib/dnn/core/models.rb', line 718

def training?
  @state != :none
end

#update ⇒ Object

Update the trainer. Each call advances the training state machine by one step.



# File 'lib/dnn/core/models.rb', line 723

def update
  case @state
  when :start_epoch
    start_epoch
  when :start_step
    start_step
  when :train_step
    train_step
  when :end_step
    end_step
  when :end_epoch
    end_epoch
  when :start_evaluate
    start_evaluate
  when :evaluating
    evaluating
  when :end_evaluate
    end_evaluate
  when :end_training
    end_training
  end
end
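
A typical driver loop (a sketch, continuing the earlier examples): after #start_train or #start_train_by_iterator sets the state to :start_epoch, call #update until #training? returns false.

trainer.start_train(x, y, 5, batch_size: 16)
trainer.update while trainer.training?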