Class: TorchModel
- Inherits: PythonModel
- Defined in:
- lib/rbbt/vector/model/torch.rb,
lib/rbbt/vector/model/torch/helpers.rb,
lib/rbbt/vector/model/torch/dataloader.rb,
lib/rbbt/vector/model/torch/introspection.rb,
lib/rbbt/vector/model/torch/load_and_save.rb
Defined Under Namespace
Modules: Tensor
Instance Attribute Summary collapse
Attributes inherited from PythonModel
#python_class, #python_module
Attributes inherited from VectorModel
#balance, #bar, #directory, #eval_model, #extract_features, #factor_levels, #features, #init_model, #labels, #model_options, #model_path, #names, #post_process, #train_model
Class Method Summary
collapse
-
.device(model_options) ⇒ Object
-
.dtype(model_options) ⇒ Object
-
.feature_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil) ⇒ Object
-
.feature_tsv(elements, labels = nil, class_labels = nil) ⇒ Object
-
.freeze(layer) ⇒ Object
-
.freeze_layer(model, layer) ⇒ Object
-
.get_layer(model, layer = nil) ⇒ Object
-
.get_weights(model, layer = nil) ⇒ Object
-
.init_python ⇒ Object
-
.load_architecture(model_path) ⇒ Object
-
.load_state(model, model_path) ⇒ Object
-
.model_architecture(model_path) ⇒ Object
-
.optimizer(model, training_args) ⇒ Object
-
.save_architecture(model, model_path) ⇒ Object
-
.save_state(model, model_path) ⇒ Object
-
.tensor(obj, device, dtype) ⇒ Object
-
.text_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil) ⇒ Object
Instance Method Summary
collapse
Methods inherited from VectorModel
R_eval, R_run, R_train, #__load_method, #add, #add_list, #balance_labels, #clear, #cross_validation, #eval, #eval_list, f1_metrics, #init, #run, #save_models, #train
Constructor Details
Returns a new instance of TorchModel.
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
|
# File 'lib/rbbt/vector/model/torch.rb', line 7
# Build a TorchModel: import the Python side, let PythonModel/VectorModel do
# the common setup, then register the init/eval/train hooks.
#
# * init_model  - reuse the architecture (and saved state) found next to
#   model_path; otherwise add @directory to the Python path and instantiate
#   @python_class from @python_module with model_options as keyword args.
# * eval_model  - move the model to the configured device, wrap the features
#   in a tensor (a single feature vector is wrapped as a batch of one) and
#   return the output wrapped with TorchModel::Tensor.
# * train_model - run training_args[:epochs] (default 3) full-batch steps with
#   #optimizer and #criterion, then save architecture and state when a
#   directory is configured.
def initialize(...)
  TorchModel.init_python

  super(...)

  @training_args = model_options[:training_args] || {}

  init_model do
    model = TorchModel.load_architecture(model_path)
    if model.nil?
      RbbtPython.add_path @directory
      RbbtPython.class_new_obj(@python_module, @python_class, **model_options)
    else
      TorchModel.load_state(model, model_path)
    end
  end

  eval_model do |features, list = false|
    init
    @device ||= TorchModel.device(model_options)
    @dtype ||= TorchModel.dtype(model_options)
    model.to(@device)

    tensor = list ? TorchModel.tensor(features, @device, @dtype) : TorchModel.tensor([features], @device, @dtype)

    # The model may return a pair (loss, output) or a single value; when the
    # second element is absent, the single value is the result.
    loss, res = model.call(tensor)
    res = loss if res.nil?

    TorchModel::Tensor.setup(list ? res : res[0])
  end

  train_model do |features, labels|
    init
    @device ||= TorchModel.device(model_options)
    @dtype ||= TorchModel.dtype(model_options)
    model.to(@device)
    @optimizer ||= TorchModel.optimizer(model, training_args)

    # Nothing visible in TorchModel assigns @criterion; fail before training
    # starts instead of with a NoMethodError on nil inside the first epoch.
    raise "TorchModel: no criterion set; cannot train" if criterion.nil?

    epochs = training_args[:epochs] || 3
    inputs = TorchModel.tensor(features, @device, @dtype)
    target = TorchModel.tensor(labels, @device, @dtype)

    Log::ProgressBar.with_bar epochs, :desc => "Training" do |bar|
      epochs.times do |i|
        @optimizer.zero_grad()
        outputs = model.call(inputs)
        # Remove size-1 dimensions so outputs line up with a 1-D target.
        outputs = outputs.squeeze() if target.dim() == 1
        loss = criterion.call(outputs, target)
        loss.backward()
        @optimizer.step
        Log.debug "Epoch #{i}, loss #{loss}"
        bar.tick
      end
    end

    if @directory
      TorchModel.save_architecture(model, model_path)
      TorchModel.save_state(model, model_path)
    end
  end
end
|
Instance Attribute Details
#criterion ⇒ Object
Returns the value of attribute criterion.
5
6
7
|
# File 'lib/rbbt/vector/model/torch.rb', line 5
# Loss function used by the training hook as criterion.call(outputs, target).
def criterion = @criterion
|
#model ⇒ Object
Returns the value of attribute model.
5
6
7
|
# File 'lib/rbbt/vector/model/torch.rb', line 5
# The underlying torch model object (built by the init_model hook).
def model = @model
|
#optimizer ⇒ Object
Returns the value of attribute optimizer.
5
6
7
|
# File 'lib/rbbt/vector/model/torch.rb', line 5
# Optimizer used by the training hook (created lazily on first training).
def optimizer = @optimizer
|
#training_args ⇒ Object
Returns the value of attribute training_args.
5
6
7
|
# File 'lib/rbbt/vector/model/torch.rb', line 5
# Training options taken from model_options[:training_args] (e.g. :epochs,
# :learning_rate).
def training_args = @training_args
|
Class Method Details
.device(model_options) ⇒ Object
26
27
28
29
30
31
32
33
34
35
|
# File 'lib/rbbt/vector/model/torch/helpers.rb', line 26
# Resolve the torch device for a model from model_options[:device]:
# * nil           -> RbbtPython.rbbt_dm.util.device()
# * String/Symbol -> torch.device(name)
# * anything else -> returned unchanged (assumed to already be a device)
def self.device(model_options)
  requested = model_options[:device]
  case requested
  when nil then RbbtPython.rbbt_dm.util.device()
  when String, Symbol then RbbtPython.torch.device(requested.to_s)
  else requested
  end
end
|
.dtype(model_options) ⇒ Object
37
38
39
40
41
42
43
44
45
46
|
# File 'lib/rbbt/vector/model/torch/helpers.rb', line 37
# Resolve the torch dtype for a model from model_options[:dtype]:
# * String/Symbol -> the torch attribute of that name (e.g. "float32" ->
#   torch.float32)
# * nil           -> nil, so torch.tensor infers the dtype itself
# * anything else -> returned unchanged (assumed to already be a dtype)
def self.dtype(model_options)
  requested = model_options[:dtype]
  case requested
  when String, Symbol
    # Dtypes are attributes of the torch module; the module itself is not
    # callable, so look the attribute up by name.
    PyCall.getattr(RbbtPython.torch, requested.to_sym)
  when nil
    nil
  else
    requested
  end
end
|
.feature_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil) ⇒ Object
38
39
40
41
42
|
# File 'lib/rbbt/vector/model/torch/dataloader.rb', line 38
# Serialise elements (and optional labels, encoded via class_labels as in
# feature_tsv) to tsv_dataset_file and return that file name.
def self.feature_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
  Open.write(tsv_dataset_file, feature_tsv(elements, labels, class_labels).to_s)
  tsv_dataset_file
end
|
.feature_tsv(elements, labels = nil, class_labels = nil) ⇒ Object
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
|
# File 'lib/rbbt/vector/model/torch/dataloader.rb', line 2
# Build an in-memory flat TSV (key field "ID", field "features") with one row
# per element, keyed by the element's position (0, 1, ...).
#
# Array elements are stored as the row itself; any other element is wrapped
# in a one-element list. When labels are given, a "label" field is added and
# each row gets its label appended, translated first through class_labels:
# * Array - the label's index in class_labels (nil if absent)
# * Hash  - the key whose value equals the label (class_labels inverted)
# * other - the label unchanged
def self.feature_tsv(elements, labels = nil, class_labels = nil)
  tsv = TSV.setup({}, :key_field => "ID", :fields => ["features"], :type => :flat)
  to_row = ->(features) { Array === features ? features : [features] }

  unless labels
    elements.each_with_index { |features, row_id| tsv[row_id] = to_row.(features) }
    return tsv
  end

  tsv.fields = tsv.fields + ["label"]

  encoded_labels =
    case class_labels
    when Array
      labels.collect { |label| class_labels.index(label) }
    when Hash
      label_to_code = {}
      class_labels.each { |code, name| label_to_code[name] = code }
      labels.collect { |label| label_to_code[label] }
    else
      labels
    end

  elements.zip(encoded_labels).each_with_index do |(features, label), row_id|
    tsv[row_id] = to_row.(features) + [label]
  end

  tsv
end
|
.freeze(layer) ⇒ Object
16
17
18
19
20
21
22
23
24
|
# File 'lib/rbbt/vector/model/torch/introspection.rb', line 16
# Recursively set requires_grad = false on the `weight` of layer and of every
# child module, so those weights receive no gradients.
#
# Layers without a `weight` attribute, or whose weight is None/nil, are
# skipped. Only `weight` is touched; other parameters (e.g. `bias`) keep
# their requires_grad setting.
#
# NOTE(review): this class method shadows Module#freeze on TorchModel, so a
# bare `TorchModel.freeze` with no argument raises ArgumentError.
def self.freeze(layer)
  # Explicit check instead of a bare rescue, so genuine errors are no
  # longer swallowed.
  if PyCall.hasattr?(layer, :weight)
    weight = PyCall.getattr(layer, :weight)
    weight.requires_grad = false unless weight.nil?
  end

  RbbtPython.iterate(layer.children) { |child| freeze(child) }
end
|
.freeze_layer(model, layer) ⇒ Object
25
26
27
28
|
# File 'lib/rbbt/vector/model/torch/introspection.rb', line 25
# Freeze the sub-module of model addressed by the dotted path layer (the
# whole model when layer is nil) using TorchModel.freeze.
def self.freeze_layer(model, layer)
  freeze(get_layer(model, layer))
end
|
.get_layer(model, layer = nil) ⇒ Object
2
3
4
5
6
7
8
|
# File 'lib/rbbt/vector/model/torch/introspection.rb', line 2
# Walk from model to a nested sub-module.
#
# layer is a dotted attribute path such as "encoder.fc" (String or Symbol);
# each segment is resolved with PyCall.getattr. A nil layer returns model.
def self.get_layer(model, layer = nil)
  return model if layer.nil?

  # to_s lets Symbol paths (e.g. :encoder) work as well as Strings.
  layer.to_s.split(".").inject(model) { |current, name| PyCall.getattr(current, name.to_sym) }
end
|
.get_weights(model, layer = nil) ⇒ Object
11
12
13
|
# File 'lib/rbbt/vector/model/torch/introspection.rb', line 11
# Return the `weight` attribute of the sub-module addressed by layer (the
# model itself when nil), wrapped with Tensor.setup.
def self.get_weights(model, layer = nil)
  target = get_layer(model, layer)
  Tensor.setup(PyCall.getattr(target, :weight))
end
|
.init_python ⇒ Object
11
12
13
14
15
16
17
|
# File 'lib/rbbt/vector/model/torch/helpers.rb', line 11
# Import, through RbbtPython, the Python modules TorchModel relies on:
# torch, rbbt and rbbt_dm, plus rbbt_dm.util and torch.nn.
def self.init_python
  %i[torch rbbt rbbt_dm].each { |mod| RbbtPython.pyimport mod }
  RbbtPython.pyfrom :rbbt_dm, import: :util
  RbbtPython.pyfrom :torch, import: :nn
end
|
.load_architecture(model_path) ⇒ Object
24
25
26
27
28
29
|
# File 'lib/rbbt/vector/model/torch/load_and_save.rb', line 24
# Load the whole model object written by save_architecture to
# "<model_path>.architecture"; returns nil when that file does not exist.
#
# Storages are mapped to CPU while loading so a model saved on a GPU can be
# restored on a host without one; the eval/train hooks move the model to the
# configured device with model.to afterwards.
#
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load
# trusted files. Newer torch releases default torch.load to
# weights_only=True, which rejects full-model pickles; confirm against the
# installed torch version.
def self.load_architecture(model_path)
  architecture_file = model_architecture(model_path)
  return unless Open.exists?(architecture_file)
  Log.debug "Loading model architecture from #{architecture_file}"
  RbbtPython.torch.load(architecture_file, map_location: "cpu")
end
|
.load_state(model, model_path) ⇒ Object
11
12
13
14
15
16
|
# File 'lib/rbbt/vector/model/torch/load_and_save.rb', line 11
# Load the state_dict saved at model_path into model, if that file exists.
# Returns model in either case.
#
# NOTE(review): torch.load unpickles its input; only load trusted files.
def self.load_state(model, model_path)
  return model unless Open.exists?(model_path)
  Log.debug "Loading model state from #{model_path}"
  # Map to CPU so state saved on a GPU loads on CPU-only hosts;
  # load_state_dict copies each tensor into the model's existing parameters.
  model.load_state_dict(RbbtPython.torch.load(model_path, map_location: "cpu"))
  model
end
|
.model_architecture(model_path) ⇒ Object
2
3
4
|
# File 'lib/rbbt/vector/model/torch/load_and_save.rb', line 2
# Path of the pickled-architecture file: model_path with ".architecture"
# appended as a suffix.
#
# Interpolation (rather than `+`) makes this work for Pathname too, whose `+`
# would join ".architecture" as a child path component.
#
# Raises ArgumentError when model_path is nil.
def self.model_architecture(model_path)
  raise ArgumentError, "model_path is required" if model_path.nil?
  "#{model_path}.architecture"
end
|
.optimizer(model, training_args) ⇒ Object
19
20
21
22
23
24
|
# File 'lib/rbbt/vector/model/torch/helpers.rb', line 19
# Build an SGD optimizer over model.parameters() with learning rate
# training_args[:learning_rate] (default 0.01).
def self.optimizer(model, training_args)
  learning_rate = training_args[:learning_rate] || 0.01
  RbbtPython.torch.optim.SGD.new(model.parameters(), lr: learning_rate)
end
|
.save_architecture(model, model_path) ⇒ Object
18
19
20
21
22
|
# File 'lib/rbbt/vector/model/torch/load_and_save.rb', line 18
# Save the whole model object with torch.save to the path returned by
# model_architecture(model_path).
def self.save_architecture(model, model_path)
  target_file = model_architecture(model_path)
  Log.debug "Saving model architecture into #{target_file}"
  RbbtPython.torch.save(model, target_file)
end
|
.save_state(model, model_path) ⇒ Object
6
7
8
9
|
# File 'lib/rbbt/vector/model/torch/load_and_save.rb', line 6
# Save model.state_dict() with torch.save to model_path.
def self.save_state(model, model_path)
  Log.debug "Saving model state into #{model_path}"
  torch = RbbtPython.torch
  torch.save(model.state_dict(), model_path)
end
|
.tensor(obj, device, dtype) ⇒ Object
48
49
50
|
# File 'lib/rbbt/vector/model/torch/helpers.rb', line 48
# Build a torch tensor from obj on the given device with the given dtype
# (a nil dtype is passed through, letting torch infer it).
def self.tensor(obj, device, dtype)
  RbbtPython.torch.tensor(obj, device: device, dtype: dtype)
end
|
.text_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil) ⇒ Object
44
45
46
47
48
49
50
51
52
53
54
55
56
|
# File 'lib/rbbt/vector/model/torch/dataloader.rb', line 44
# Write a text dataset TSV to tsv_dataset_file and return that file name.
#
# Newlines inside each text are replaced by spaces so every text stays on a
# single TSV line. The first field is renamed "text"; the TSV type is
# :single without labels and :list with labels (label encoding as in
# feature_tsv).
def self.text_dataset(tsv_dataset_file, elements, labels = nil, class_labels = nil)
  one_line_texts = elements.map { |text| text.gsub("\n", ' ') }
  tsv = feature_tsv(one_line_texts, labels, class_labels)
  tsv.fields[0] = "text"
  tsv.type = labels.nil? ? :single : :list
  Open.write(tsv_dataset_file, tsv.to_s)
  tsv_dataset_file
end
|
Instance Method Details
#freeze_layer ⇒ Object
29
|
# File 'lib/rbbt/vector/model/torch/introspection.rb', line 29
# Instance shortcut: TorchModel.freeze_layer applied to this object's #model,
# forwarding all remaining arguments.
def freeze_layer(...)
  TorchModel.freeze_layer(model, ...)
end
|
#get_layer ⇒ Object
9
|
# File 'lib/rbbt/vector/model/torch/introspection.rb', line 9
# Instance shortcut: TorchModel.get_layer applied to this object's #model,
# forwarding all remaining arguments.
def get_layer(...)
  TorchModel.get_layer(model, ...)
end
|
#get_weights ⇒ Object
14
|
# File 'lib/rbbt/vector/model/torch/introspection.rb', line 14
# Instance shortcut: TorchModel.get_weights applied to this object's #model,
# forwarding all remaining arguments.
def get_weights(...)
  TorchModel.get_weights(model, ...)
end
|