Class: TensorFlowModel

Inherits:
VectorModel show all
Defined in:
lib/rbbt/vector/model/tensorflow.rb

Instance Attribute Summary collapse

Attributes inherited from VectorModel

#bar, #directory, #eval_model, #extract_features, #factor_levels, #features, #labels, #model_file, #names, #train_model

Instance Method Summary collapse

Methods inherited from VectorModel

R_eval, R_run, R_train, #__load_method, #add, #add_list, #clear, #cross_validation, #eval, #eval_list, f1_metrics, #run, #save_models, #train

Constructor Details

#initialize(dir, graph = nil, epochs = 3, **compile_options) ⇒ TensorFlowModel

Returns a new instance of TensorFlowModel.



21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# File 'lib/rbbt/vector/model/tensorflow.rb', line 21

# Builds a TensorFlow-backed model for +dir+.
#
# Stores the graph, the epoch count and the compile options, calls the
# superclass constructor with +dir+, and then installs the two procs held
# in @train_model and @eval_model.
#
# @param dir passed unchanged to the superclass constructor
# @param graph object stored in @graph; may be nil
# @param epochs value forwarded as :epochs to the graph's fit call
# @param compile_options keyword options forwarded to the graph's compile call
def initialize(dir, graph = nil, epochs = 3, **compile_options)
  @graph, @epochs, @compile_options = graph, epochs, compile_options

  super(dir)

  # Training proc: converts features and labels to tensors, then compiles,
  # fits and saves the graph to +file+.
  @train_model = proc do |file, features, labels|
    tensorflow do
      features = tensorflow.convert_to_tensor(features)
      labels = tensorflow.convert_to_tensor(labels)
    end
    # NOTE(review): keras_graph is called without a block when @graph is
    # unset -- confirm how the graph is expected to be built in that case.
    @graph ||= keras_graph
    @graph.compile(**@compile_options)
    @graph.fit(features, labels, epochs: @epochs, verbose: true)
    @graph.save(file)
  end

  # Evaluation proc: converts features to a tensor, predicts, and reduces
  # every prediction row to a single label.
  @eval_model = proc do |file, features|
    tensorflow do
      features = tensorflow.convert_to_tensor(features)
    end
    keras do
      # NOTE(review): this block is handed to the keras helper; presumably
      # @graph here resolves in whatever scope that helper evaluates the
      # block in -- confirm it is the same @graph the training proc uses.
      @graph ||= keras.models.load_model(file)
      predictions = @graph.predict(features, verbose: false).tolist
      predictions.map do |row|
        # Rows with several values yield the position of the largest one;
        # single-value rows yield that value itself.
        if row.length > 1
          row.index(row.max)
        else
          row.first
        end
      end
    end
  end
end

Instance Attribute Details

#compile_options ⇒ Object

Returns the value of attribute compile_options.



5
6
7
# File 'lib/rbbt/vector/model/tensorflow.rb', line 5

# Reader for the compile options.
#
# @return the value currently stored in @compile_options (nil when unset)
def compile_options; @compile_options; end

#epochs ⇒ Object

Returns the value of attribute epochs.



5
6
7
# File 'lib/rbbt/vector/model/tensorflow.rb', line 5

# Reader for the epoch count.
#
# @return the value currently stored in @epochs (nil when unset)
def epochs; @epochs; end

#graph ⇒ Object

Returns the value of attribute graph.



5
6
7
# File 'lib/rbbt/vector/model/tensorflow.rb', line 5

# Reader for the model graph.
#
# @return the value currently stored in @graph (nil when unset)
def graph; @graph; end

Instance Method Details

#keras(&block) ⇒ Object



13
14
15
16
17
18
19
# File 'lib/rbbt/vector/model/tensorflow.rb', line 13

# Evaluates +block+ through RbbtPython.module_eval, nested inside two
# RbbtPython.run calls: the outer one for "tensorflow.keras" (passed with
# as: 'keras') and the inner one for "tensorflow".
#
# @param block code to hand to RbbtPython.module_eval
# @return whatever the outer RbbtPython.run call returns (presumably the
#   block's value -- confirm against RbbtPython.run)
def keras(&block)
  RbbtPython.run("tensorflow.keras", as: 'keras') do
    RbbtPython.run("tensorflow") { RbbtPython.module_eval(&block) }
  end
end

#keras_graph(&block) ⇒ Object



52
53
54
# File 'lib/rbbt/vector/model/tensorflow.rb', line 52

# Runs +block+ through the keras helper and remembers the result in @graph.
#
# @param block forwarded unchanged to keras
# @return the value returned by keras, which is also stored in @graph
def keras_graph(&block)
  built = keras(&block)
  @graph = built
end

#tensorflow(&block) ⇒ Object



7
8
9
10
11
# File 'lib/rbbt/vector/model/tensorflow.rb', line 7

# Evaluates +block+ through RbbtPython.module_eval inside an
# RbbtPython.run call for "tensorflow".
#
# @param block code to hand to RbbtPython.module_eval
# @return whatever RbbtPython.run returns (presumably the block's value --
#   confirm against RbbtPython.run)
def tensorflow(&block)
  RbbtPython.run("tensorflow") { RbbtPython.module_eval(&block) }
end