Class: TensorFlowModel

Inherits:
VectorModel
Defined in:
lib/rbbt/vector/model/tensorflow.rb

Instance Attribute Summary

Attributes inherited from VectorModel

#directory, #eval_model, #extract_features, #factor_levels, #features, #labels, #model_file, #names, #train_model

Instance Method Summary

Methods inherited from VectorModel

R_eval, R_run, R_train, #__load_method, #add, #add_list, #clear, #cross_validation, #eval, #eval_list, f1_metrics, #run, #save_models, #train

Constructor Details

#initialize(dir, graph = nil, epochs = 3, **compile_options) ⇒ TensorFlowModel



21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# File 'lib/rbbt/vector/model/tensorflow.rb', line 21

# Build a vector model whose training and evaluation are delegated to a
# Keras graph driven through RbbtPython.
#
# @param dir directory handed to VectorModel#initialize via +super+
# @param graph graph to train; when nil, #keras_graph is called at training
#   time to provide one
# @param epochs number of epochs passed to the graph's +fit+ (default 3)
# @param compile_options [Hash] keyword options passed to the graph's +compile+
def initialize(dir, graph = nil, epochs = 3, **compile_options)
  @graph = graph
  @epochs = epochs
  @compile_options = compile_options

  super(dir)

  # Convert the inputs to tensors, compile and fit the graph, then save it
  # to +file+.
  @train_model = Proc.new do |file, features, labels|
    # Inside this block self is RbbtPython (see #tensorflow), so +tensorflow+
    # resolves there; the reassigned names are this proc's captured params.
    tensorflow do
      features = tensorflow.convert_to_tensor(features)
      labels = tensorflow.convert_to_tensor(labels)
    end
    @graph ||= keras_graph
    @graph.compile(**@compile_options)
    @graph.fit(features, labels, :epochs => @epochs, :verbose => true)
    @graph.save(file)
  end

  # Predict on +features+ and return one label per row: the index of the
  # largest output when a row has several outputs, otherwise its only value.
  # If the model holds no graph yet, it is loaded from +file+ and kept.
  @eval_model = Proc.new do |file, features|
    tensorflow do
      features = tensorflow.convert_to_tensor(features)
    end
    # Carry the graph through a local: inside the keras block self is
    # RbbtPython (module_eval), so an @graph written there would belong to
    # RbbtPython and be shared by every model instead of this one.
    graph = @graph
    predictions = keras do
      graph ||= keras.models.load_model(file)
      graph.predict(features, :verbose => false).tolist()
    end
    @graph = graph
    predictions.collect { |p| p.length > 1 ? p.index(p.max) : p.first }
  end
end

Instance Attribute Details

#compile_options ⇒ Object

Returns the value of attribute compile_options.



5
6
7
# File 'lib/rbbt/vector/model/tensorflow.rb', line 5

# Keyword options captured by the constructor and splatted into the graph's
# +compile+ call during training.
#
# @return [Hash]
def compile_options; @compile_options; end

#epochs ⇒ Object

Returns the value of attribute epochs.



5
6
7
# File 'lib/rbbt/vector/model/tensorflow.rb', line 5

# Epoch count handed to the graph's +fit+ call during training; the
# constructor defaults it to 3.
#
# @return the stored epoch count
def epochs; @epochs; end

#graph ⇒ Object

Returns the value of attribute graph.



5
6
7
# File 'lib/rbbt/vector/model/tensorflow.rb', line 5

# The graph this model trains: the one given to the constructor or built
# with #keras_graph. Presumably a Keras model, since it receives
# compile/fit/save/predict calls — confirm against callers.
#
# @return the stored graph, or nil if none has been set
def graph; @graph; end

Instance Method Details

#keras(&block) ⇒ Object



13
14
15
16
17
18
19
# File 'lib/rbbt/vector/model/tensorflow.rb', line 13

# Evaluate +block+ with self set to RbbtPython, nested inside
# RbbtPython.run calls for "tensorflow.keras" (with +as: 'keras'+,
# presumably exposing it as +keras+) and "tensorflow".
#
# @return whatever RbbtPython.run returns; callers in this class use it as
#   the block's value
def keras(&block)
  RbbtPython.run("tensorflow.keras", as: 'keras') do
    RbbtPython.run("tensorflow") { RbbtPython.module_eval(&block) }
  end
end

#keras_graph(&block) ⇒ Object



52
53
54
# File 'lib/rbbt/vector/model/tensorflow.rb', line 52

# Build this model's graph by evaluating +block+ through #keras (self is
# RbbtPython inside the block), store the result as the model's graph and
# return it.
#
# @yield builds and returns the graph
# @return the new graph
# @raise [ArgumentError] if no block is given
def keras_graph(&block)
  # Without a block, #keras would hand nil to module_eval and fail with an
  # opaque arity error; report the real problem instead.
  raise ArgumentError, "keras_graph requires a block that builds the graph" if block.nil?

  @graph = keras(&block)
end

#tensorflow(&block) ⇒ Object



7
8
9
10
11
# File 'lib/rbbt/vector/model/tensorflow.rb', line 7

# Evaluate +block+ with self set to RbbtPython inside a
# RbbtPython.run("tensorflow") call (presumably making the module available
# as +tensorflow+ within the block).
#
# @return whatever RbbtPython.run returns
def tensorflow(&block)
  RbbtPython.run("tensorflow") { RbbtPython.module_eval(&block) }
end