Class: TorchModel

Inherits:
PythonModel
Defined in:
lib/rbbt/vector/model/torch.rb,
lib/rbbt/vector/model/torch/helpers.rb,
lib/rbbt/vector/model/torch/dataloader.rb,
lib/rbbt/vector/model/torch/introspection.rb,
lib/rbbt/vector/model/torch/load_and_save.rb

Direct Known Subclasses

HuggingfaceModel, PytorchLightningModel

Defined Under Namespace

Modules: Tensor

Instance Attribute Summary

Attributes inherited from PythonModel

#python_class, #python_module

Attributes inherited from VectorModel

#balance, #bar, #directory, #eval_model, #extract_features, #factor_levels, #features, #init_model, #labels, #model_options, #model_path, #names, #post_process, #train_model

Class Method Summary

Instance Method Summary

Methods inherited from VectorModel

R_eval, R_run, R_train, #__load_method, #add, #add_list, #balance_labels, #clear, #cross_validation, #eval, #eval_list, f1_metrics, #init, #run, #save_models, #train

Constructor Details

#initialize(...) ⇒ TorchModel

Returns a new instance of TorchModel.

(source lines 7–66)
# File 'lib/rbbt/vector/model/torch.rb', line 7

# Creates a TorchModel. All arguments are forwarded unchanged to
# PythonModel#initialize; afterwards the init_model, eval_model and
# train_model hooks are registered.
#
# model_options keys read by this method and its hooks:
#   :training_args - Hash of training settings (:epochs here, plus what
#                    TorchModel.optimizer reads, e.g. :learning_rate)
#   :device        - resolved through TorchModel.device
#   :dtype         - resolved through TorchModel.dtype
def initialize(...)
  # Import the Python-side modules (torch, rbbt, rbbt_dm, ...) first.
  TorchModel.init_python
  super(...)
  @training_args = model_options[:training_args] || {}

  # Build the model. When an architecture file exists at
  # model_path + '.architecture', reload that module and its saved state
  # dict; otherwise build a new object from @python_module/@python_class.
  # `model` in this block is a block-local variable, not the #model reader.
  # NOTE(review): **model_options forwards every key, including
  # :training_args, :device and :dtype, to the Python constructor —
  # confirm the Python class accepts them.
  init_model do
    model = TorchModel.load_architecture(model_path) 
    if model.nil?
      RbbtPython.add_path @directory 
      RbbtPython.class_new_obj(@python_module, @python_class, **model_options)
    else
      TorchModel.load_state(model, model_path)
    end
  end

  # Evaluate one feature vector, or a batch of them when list is true. A
  # single vector is wrapped into a one-element batch and unwrapped again.
  # NOTE(review): the module is not put in eval mode and gradient tracking
  # is not disabled here — confirm that is intended.
  eval_model do |features,list=false|
    init
    # Device and dtype are resolved once and memoized on the instance.
    @device ||= TorchModel.device(model_options)
    @dtype ||= TorchModel.dtype(model_options)
    model.to(@device)

    tensor = list ? TorchModel.tensor(features, @device, @dtype) : TorchModel.tensor([features], @device, @dtype)

    # Presumably supports modules returning a (loss, output) pair. When the
    # result does not destructure, Ruby assigns it to `loss` and leaves
    # `res` nil, so the next line falls back to it.
    loss, res = model.call(tensor)

    res = loss if res.nil?

    res = TorchModel::Tensor.setup(list ? res : res[0])

    res
  end

  # Full-batch training: all features and labels become one tensor each,
  # and every epoch performs exactly one optimizer step over them.
  train_model do |features,labels|
    init
    @device ||= TorchModel.device(model_options)
    @dtype ||= TorchModel.dtype(model_options)
    model.to(@device)
    # Memoized, so the same optimizer object is reused across train calls.
    @optimizer ||= TorchModel.optimizer(model, training_args)
    epochs = training_args[:epochs] || 3

    inputs = TorchModel.tensor(features, @device, @dtype)
    #target = TorchModel.tensor(labels.collect{|v| [v] }, @device, @dtype)
    target = TorchModel.tensor(labels, @device, @dtype)

    Log::ProgressBar.with_bar epochs, :desc => "Training" do |bar|
      epochs.times do |i|
        @optimizer.zero_grad()
        outputs = model.call(inputs)
        # Drop singleton dimensions so the outputs line up with a 1-D
        # target (presumably (N, 1) -> (N)).
        outputs = outputs.squeeze() if target.dim() == 1
        # NOTE(review): `criterion` reads @criterion, which nothing in this
        # method assigns; it must be set before training or this call fails.
        loss = criterion.call(outputs, target)
        loss.backward()
        @optimizer.step
        Log.debug "Epoch #{i}, loss #{loss}"
        bar.tick
      end
    end
    # Persist the whole module object and its state dict, but only when the
    # model has a directory to save into.
    TorchModel.save_architecture(model, model_path) if @directory
    TorchModel.save_state(model, model_path) if @directory
  end
end

Instance Attribute Details

#criterion ⇒ Object

Returns the value of attribute criterion.

(source lines 5–7)
# File 'lib/rbbt/vector/model/torch.rb', line 5

# Loss function called by the training hook; nil until something assigns
# @criterion.
def criterion = @criterion

#model ⇒ Object

Returns the value of attribute model.

(source lines 5–7)
# File 'lib/rbbt/vector/model/torch.rb', line 5

# The wrapped Python model object held in @model.
def model = @model

#optimizer ⇒ Object

Returns the value of attribute optimizer.

(source lines 5–7)
# File 'lib/rbbt/vector/model/torch.rb', line 5

# Optimizer created lazily by the training hook and reused afterwards.
def optimizer = @optimizer

#training_args ⇒ Object

Returns the value of attribute training_args.

(source lines 5–7)
# File 'lib/rbbt/vector/model/torch.rb', line 5

# Training settings taken from model_options[:training_args] (an empty Hash
# when that option is absent).
def training_args = @training_args

Class Method Details

.device(model_options) ⇒ Object



(source lines 26–35)
# File 'lib/rbbt/vector/model/torch/helpers.rb', line 26

# Resolves the torch device from model_options[:device]:
#   nil           -> the default returned by rbbt_dm's util.device()
#   String/Symbol -> torch.device(name)
#   anything else -> returned unchanged (assumed to already be a device)
def self.device(model_options)
  requested = model_options[:device]
  return RbbtPython.rbbt_dm.util.device() if requested.nil?
  return requested unless String === requested || Symbol === requested

  RbbtPython.torch.device(requested.to_s)
end

.dtype(model_options) ⇒ Object



(source lines 37–46)
# File 'lib/rbbt/vector/model/torch/helpers.rb', line 37

# Resolves the torch dtype from model_options[:dtype]:
#   String/Symbol -> the torch attribute of that name ("float32",
#                    :float64, or "torch.float16" with the prefix dropped)
#   nil           -> nil (no explicit dtype)
#   anything else -> returned unchanged (assumed to already be a dtype)
#
# dtypes are attributes of the torch module (torch.float32, ...), so they
# are looked up with getattr. Calling the module object does not give one.
def self.dtype(model_options)
  requested = model_options[:dtype]
  case requested
  when String, Symbol
    name = requested.to_s.delete_prefix('torch.')
    PyCall.getattr(RbbtPython.torch, name.to_sym)
  when nil
    nil
  else
    requested
  end
end

.feature_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil) ⇒ Object



(source lines 38–42)
# File 'lib/rbbt/vector/model/torch/dataloader.rb', line 38

# Serializes elements (and optional labels) with feature_tsv, writes the
# text to tsv_dataset_file and returns that path.
def self.feature_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
  content = feature_tsv(elements, labels, class_labels).to_s
  Open.write(tsv_dataset_file, content)
  tsv_dataset_file
end

.feature_tsv(elements, labels = nil, class_labels = nil) ⇒ Object



(source lines 2–36)
# File 'lib/rbbt/vector/model/torch/dataloader.rb', line 2

# Builds an in-memory :flat TSV with one row per element, keyed by the
# element's position (0, 1, ...). An Array element supplies the whole row;
# any other element becomes a one-value row.
#
# When labels are given, a "label" field is appended to every row. With
# class_labels the labels are first encoded:
#   Array -> the label's position in class_labels
#   Hash  -> the key whose value is the label (read as {code => label})
# A label missing from class_labels is encoded as nil.
def self.feature_tsv(elements, labels = nil, class_labels = nil)
  tsv = TSV.setup({}, :key_field => "ID", :fields => ["features"], :type => :flat)

  unless labels
    elements.each_with_index do |features, id|
      tsv[id] = Array === features ? features : [features]
    end
    return tsv
  end

  tsv.fields = tsv.fields + ["label"]

  encoded = if Array === class_labels
              labels.map { |label| class_labels.index label }
            elsif Hash === class_labels
              code_for = {}
              class_labels.each { |code, label| code_for[label] = code }
              labels.map { |label| code_for[label] }
            else
              labels
            end

  elements.zip(encoded).each_with_index do |(features, label), id|
    tsv[id] = Array === features ? features + [label] : [features, label]
  end

  tsv
end

.freeze(layer) ⇒ Object



(source lines 16–24)
# File 'lib/rbbt/vector/model/torch/introspection.rb', line 16

# Sets requires_grad = false on layer's `weight` and then, recursively, on
# the weights of all of its child modules.
#
# A layer without a `weight` attribute, or whose weight is None/nil, is
# skipped. These cases are checked explicitly instead of being rescued, so
# unrelated errors (typos, PyCall failures on a real weight) still surface.
# NOTE(review): only `weight` is frozen; biases and any other parameters
# stay trainable.
def self.freeze(layer)
  if PyCall.hasattr?(layer, :weight)
    weight = PyCall.getattr(layer, :weight)
    weight.requires_grad = false unless weight.nil?
  end
  RbbtPython.iterate(layer.children) do |child|
    freeze(child)
  end
end

.freeze_layer(model, layer) ⇒ Object



(source lines 25–28)
# File 'lib/rbbt/vector/model/torch/introspection.rb', line 25

# Looks up the (possibly dotted) layer name inside model and freezes that
# sub-module.
def self.freeze_layer(model, layer)
  freeze(get_layer(model, layer))
end

.get_layer(model, layer = nil) ⇒ Object



(source lines 2–8)
# File 'lib/rbbt/vector/model/torch/introspection.rb', line 2

# Walks a dotted attribute path ("encoder.fc") from model, resolving each
# step with PyCall.getattr. With no layer name, returns model itself.
def self.get_layer(model, layer = nil)
  return model if layer.nil?

  layer.split(".").reduce(model) do |current, attribute|
    PyCall.getattr(current, attribute.to_sym)
  end
end

.get_weights(model, layer = nil) ⇒ Object



(source lines 11–13)
# File 'lib/rbbt/vector/model/torch/introspection.rb', line 11

# Returns the `weight` attribute of the named layer (or of model itself when
# layer is nil), wrapped with Tensor.setup.
def self.get_weights(model, layer = nil)
  target = get_layer(model, layer)
  weight = PyCall.getattr(target, :weight)
  Tensor.setup(weight)
end

.init_python ⇒ Object

(source lines 11–17)
# File 'lib/rbbt/vector/model/torch/helpers.rb', line 11

# Imports the Python modules this class relies on (torch, rbbt, rbbt_dm),
# then imports util from rbbt_dm and nn from torch.
def self.init_python
  %i[torch rbbt rbbt_dm].each { |mod| RbbtPython.pyimport mod }
  RbbtPython.pyfrom :rbbt_dm, import: :util
  RbbtPython.pyfrom :torch, import: :nn
end

.load_architecture(model_path) ⇒ Object



(source lines 24–29)
# File 'lib/rbbt/vector/model/torch/load_and_save.rb', line 24

# Loads the whole module object that save_architecture wrote for
# model_path. Returns nil when no architecture file exists.
#
# NOTE(review): torch.load unpickles arbitrary objects, so point it only at
# trusted files. Recent torch releases default torch.load to
# weights_only=True, which refuses full modules; confirm the torch version
# in use.
def self.load_architecture(model_path)
  architecture_file = model_architecture(model_path)
  return nil unless Open.exists?(architecture_file)

  Log.debug "Loading model architecture from #{architecture_file}"
  RbbtPython.torch.load(architecture_file)
end

.load_state(model, model_path) ⇒ Object



(source lines 11–16)
# File 'lib/rbbt/vector/model/torch/load_and_save.rb', line 11

# Loads the state dict stored at model_path (as written by save_state) into
# model. The model is left untouched when the file does not exist; in both
# cases the model is returned.
def self.load_state(model, model_path)
  if Open.exists?(model_path)
    Log.debug "Loading model state from #{model_path}"
    model.load_state_dict(RbbtPython.torch.load(model_path))
  end
  model
end

.model_architecture(model_path) ⇒ Object



(source lines 2–4)
# File 'lib/rbbt/vector/model/torch/load_and_save.rb', line 2

# Path of the file holding the saved module object for model_path: the same
# path with an ".architecture" suffix.
def self.model_architecture(model_path) = model_path + '.architecture'

.optimizer(model, training_args) ⇒ Object



(source lines 19–24)
# File 'lib/rbbt/vector/model/torch/helpers.rb', line 19

# Builds a torch SGD optimizer over all of model's parameters. The step
# size comes from training_args[:learning_rate] and defaults to 0.01.
def self.optimizer(model, training_args)
  lr = training_args[:learning_rate] || 0.01
  RbbtPython.torch.optim.SGD.new(model.parameters(), lr: lr)
end

.save_architecture(model, model_path) ⇒ Object



(source lines 18–22)
# File 'lib/rbbt/vector/model/torch/load_and_save.rb', line 18

# Saves the whole model object with torch.save into the architecture file
# derived from model_path (see model_architecture).
def self.save_architecture(model, model_path)
  target = model_architecture(model_path)
  Log.debug "Saving model architecture into #{target}"
  RbbtPython.torch.save(model, target)
end

.save_state(model, model_path) ⇒ Object



(source lines 6–9)
# File 'lib/rbbt/vector/model/torch/load_and_save.rb', line 6

# Writes model.state_dict() to model_path with torch.save.
def self.save_state(model, model_path)
  Log.debug "Saving model state into #{model_path}"
  torch = RbbtPython.torch
  torch.save(model.state_dict(), model_path)
end

.tensor(obj, device, dtype) ⇒ Object



(source lines 48–50)
# File 'lib/rbbt/vector/model/torch/helpers.rb', line 48

# Converts obj (e.g. feature arrays) into a torch tensor on the given device
# with the given dtype. A nil dtype is passed through as-is.
def self.tensor(obj, device, dtype)
  options = { device: device, dtype: dtype }
  RbbtPython.torch.tensor(obj, **options)
end

.text_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil) ⇒ Object



(source lines 44–56)
# File 'lib/rbbt/vector/model/torch/dataloader.rb', line 44

# Writes a text dataset to tsv_dataset_file as a TSV and returns that path.
#
# Each text is flattened to one line first. Tabs, carriage returns and
# newlines (a CRLF pair counts as one break) become single spaces, so a
# text cannot be split across columns or rows of the tab-separated output.
# Rows come from feature_tsv, which also encodes labels via class_labels.
# The first field is renamed to "text". Without labels the TSV is written
# as :single (one text per row); with labels it is written as :list (text
# and label). Labels are tested for truthiness, as in feature_tsv.
def self.text_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
  elements = elements.map { |e| e.gsub(/\r\n|[\t\r\n]/, ' ') }
  tsv = feature_tsv(elements, labels, class_labels)
  # Assign through the setter rather than mutating the returned fields
  # array, matching how feature_tsv updates fields.
  tsv.fields = ["text"] + tsv.fields[1..-1]
  tsv.type = labels ? :list : :single
  Open.write(tsv_dataset_file, tsv.to_s)
  tsv_dataset_file
end

Instance Method Details

#freeze_layer(...) ⇒ Object

(source line 29)
# File 'lib/rbbt/vector/model/torch/introspection.rb', line 29

# Instance shortcut: TorchModel.freeze_layer applied to this object's model.
def freeze_layer(...)
  TorchModel.freeze_layer(model, ...)
end

#get_layer(...) ⇒ Object

(source line 9)
# File 'lib/rbbt/vector/model/torch/introspection.rb', line 9

# Instance shortcut: TorchModel.get_layer applied to this object's model.
def get_layer(...)
  TorchModel.get_layer(model, ...)
end

#get_weights(...) ⇒ Object

(source line 14)
# File 'lib/rbbt/vector/model/torch/introspection.rb', line 14

# Instance shortcut: TorchModel.get_weights applied to this object's model.
def get_weights(...)
  TorchModel.get_weights(model, ...)
end