Module: Torchrb::NN::Basic

Defined in:
lib/torchrb/nn/basic.rb

Instance Method Summary

Instance Method Details

#define_nn ⇒ Object



Source (lib/torchrb/nn/basic.rb, lines 3–13):
# File 'lib/torchrb/nn/basic.rb', line 3

# Defines the network by evaluating a Torch7 (Lua) snippet through +torch.eval+
# and returns that call's result converted with +to_h+.
#
# The snippet assigns +net+ (no +local+, so a Lua global) to an +nn.Sequential+
# container holding, in order:
#   nn.Linear(input_size  -> hidden_size)
#   nn.Linear(hidden_size -> model.classes.size)
#   nn.LogSoftMax()
#
# NOTE(review): there is no activation (e.g. nn.Tanh / nn.ReLU) between the two
# Linear layers, so together they compose to a single linear map. The
# architecture is kept as-is to preserve existing behavior — confirm intent.
#
# @param input_size [Integer] number of inputs to the first Linear layer
#   (defaults to 1, the previously hard-coded value).
# @param hidden_size [Integer] width of the intermediate layer
#   (defaults to 80, the previously hard-coded value).
# @return [Hash] the return value of +torch.eval+ converted via +to_h+.
# @raise [ArgumentError, TypeError] if a size cannot be converted by +Integer()+.
def define_nn(input_size: 1, hidden_size: 80)
  # Coerce to Integer so that only numbers are ever interpolated into the Lua code.
  input_size = Integer(input_size)
  hidden_size = Integer(hidden_size)
  output_size = model.classes.size
  # The heredoc body begins on the line after this call, so pass __LINE__ + 1
  # so that any location reported for the snippet points at the right line.
  torch.eval(<<-EOF, __FILE__, __LINE__ + 1).to_h
    net = nn.Sequential()
    net:add(nn.Linear(#{input_size}, #{hidden_size}))
    net:add(nn.Linear(#{hidden_size}, #{output_size}))
    net:add(nn.LogSoftMax())
  EOF
end