Class: RubyZero::Utils::Trainer
- Inherits: Object
  - Object
    - RubyZero::Utils::Trainer
- Defined in:
- lib/rubyzero/utils/trainer.rb
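Instance Method Summary
- #initialize(model, loss_function: nil, optimizer: nil) ⇒ Trainer (constructor)
  Returns a new instance of Trainer.
- #train(train_data, test_data, num_epochs: 1, batch_size: 1, shuffle: true, show_graph: false, show_graph_finish: true) ⇒ Object
  Trains the model on train_data for num_epochs epochs, evaluating on test_data after each epoch, and returns the final epoch's mean training loss.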
Constructor Details
#initialize(model, loss_function: nil, optimizer: nil) ⇒ Trainer
Returns a new instance of Trainer.
Source lines 3–7:
# File 'lib/rubyzero/utils/trainer.rb', line 3
# Wraps +model+ together with the loss and optimizer that #train uses.
#
# @param model the network to train; its +parameters+ are handed to the
#   default optimizer
# @param loss_function [#call, nil] loss object; when nil (or false) a
#   RubyZero::NN::Losses::MSE is created
# @param optimizer optimizer object; when nil (or false) a
#   RubyZero::NN::Optimizers::Momentum over +model.parameters+ is created
def initialize(model, loss_function: nil, optimizer: nil)
  @model = model
  @optimizer = optimizer
  @optimizer ||= RubyZero::NN::Optimizers::Momentum.new(model.parameters)
  @loss_function = loss_function
  @loss_function ||= RubyZero::NN::Losses::MSE.new
end
Instance Method Details
#train(train_data, test_data, num_epochs: 1, batch_size: 1, shuffle: true, show_graph: false, show_graph_finish: true) ⇒ Object
Source lines 9–42:
# File 'lib/rubyzero/utils/trainer.rb', line 9
# Trains the model, evaluating it on +test_data+ after every epoch.
#
# Each epoch makes one pass over +train_data+ (zero_grad, forward, backward,
# optimizer step per batch), then a forward-only pass over +test_data+.
# The per-epoch mean losses are collected and can optionally be plotted.
#
# @param train_data dataset wrapped in a RubyZero::Data::DataLoader for training
# @param test_data dataset wrapped in a RubyZero::Data::DataLoader for evaluation;
#   when it yields no batches, no test loss is recorded, printed or plotted
# @param num_epochs [Integer] number of passes over +train_data+
# @param batch_size [Integer] batch size given to both data loaders
# @param shuffle [Boolean] shuffle flag given to both data loaders
# @param show_graph [Boolean] redraw the loss plot after every epoch
# @param show_graph_finish [Boolean] draw the loss plot after the last epoch
# @return mean training loss of the last epoch, or nil when +num_epochs+ is 0
# @raise [ArgumentError] if +train_data+ yields no batches
def train(train_data, test_data, num_epochs: 1, batch_size: 1, shuffle: true, show_graph: false, show_graph_finish: true)
  train_loader = RubyZero::Data::DataLoader.new(train_data, batch_size: batch_size, shuffle: shuffle)
  test_loader = RubyZero::Data::DataLoader.new(test_data, batch_size: batch_size, shuffle: shuffle)
  losses_train = []
  losses_test = []
  # Mean of a non-empty list; fdiv avoids truncation if loss values are Integers.
  mean = ->(values) { values.sum.fdiv(values.size) }

  num_epochs.times do |epoch|
    epoch_losses_train = []
    train_loader.each do |input, target|
      @optimizer.zero_grad
      loss = @loss_function.call(@model.call(input), target)
      loss.backward
      @optimizer.step
      epoch_losses_train << loss.data[0]
    end
    # An epoch without batches has no mean loss; fail clearly instead of dividing nil by 0.
    raise ArgumentError, 'train_data yielded no batches' if epoch_losses_train.empty?

    epoch_losses_test = []
    test_loader.each do |input, target|
      epoch_losses_test << @loss_function.call(@model.call(input), target).data[0]
    end

    avg_loss_train = mean.call(epoch_losses_train)
    losses_train << avg_loss_train
    # Empty test data is tolerated: skip the test loss rather than crash.
    avg_loss_test = epoch_losses_test.empty? ? nil : mean.call(epoch_losses_test)
    losses_test << avg_loss_test unless avg_loss_test.nil?

    next unless show_graph || (show_graph_finish && epoch == num_epochs - 1)

    clear_console
    epochs = (0..epoch).to_a
    plot = UnicodePlot.lineplot(epochs, losses_train, name: "train loss")
    # Only plot the test curve when there is one value per epoch to plot.
    UnicodePlot.lineplot!(plot, epochs, losses_test, name: "test loss") unless losses_test.empty?
    plot.render
    if avg_loss_test.nil?
      puts "train loss:#{avg_loss_train}"
    else
      puts "train loss:#{avg_loss_train}\n test loss:#{avg_loss_test}"
    end
  end

  losses_train.last
end