Class: HybridForest::Trees::Tree

Inherits:
Object
Defined in:
lib/hybridforest/trees/tree.rb

Direct Known Subclasses

CARTTree, ID3Tree

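Instance Method Summary

#fit(instances) ⇒ Object
Fits a model to the given dataset instances and returns self.

#initialize(tree_grower:) ⇒ Tree (constructor)
Creates a new Tree using the specified tree growing algorithm.

#inspect ⇒ Object
Returns a string representation of this Tree.

#predict(instances) ⇒ Object
Predicts a label for each instance in the dataset instances and returns an array of labels.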

Constructor Details

#initialize(tree_grower:) ⇒ Tree

Creates a new Tree using the specified tree growing algorithm.

# File 'lib/hybridforest/trees/tree.rb', line 9

def initialize(tree_grower:)
  @tree_grower = tree_grower
end
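The constructor only stores tree_grower; judging from the source on this page, Tree relies on a duck-typed contract: the grower responds to grow_tree(instances), and the root it returns responds to classify(instance) and print_string. The sketch below is a hypothetical grower and node that satisfy that contract. MajorityGrower and LeafNode are not part of HybridForest, and plain arrays of hashes with a :label key stand in for whatever HybridForest::Utils.to_dataframe actually produces.

```ruby
# Hypothetical stand-ins (not HybridForest classes) for the contract Tree uses:
# grower.grow_tree(instances) -> root; root.classify(instance); root.print_string.

# A leaf node that always predicts a single label.
class LeafNode
  def initialize(label)
    @label = label
  end

  # Ignores the instance's features and returns the stored label.
  def classify(_instance)
    @label
  end

  def print_string
    "Leaf(#{@label})"
  end
end

# A trivial grower: returns one leaf holding the most frequent label.
# Instances are assumed to be hashes with a :label key (an assumption
# for illustration only).
class MajorityGrower
  def grow_tree(instances)
    counts = instances.map { |row| row[:label] }.tally # Ruby 2.7+
    majority, _count = counts.max_by { |_label, count| count }
    LeafNode.new(majority)
  end
end

root = MajorityGrower.new.grow_tree([
  { x: 1, label: "yes" },
  { x: 2, label: "no" },
  { x: 3, label: "yes" }
])
puts root.print_string        # prints Leaf(yes)
puts root.classify({ x: 42 }) # prints yes
```

Passing a different grower (for example, the ones behind the CARTTree and ID3Tree subclasses) changes the learning algorithm without changing Tree itself, which is a strategy-style design.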

Instance Method Details

#fit(instances) ⇒ Object

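Fits a model to the given dataset instances and returns self. The instances are first converted with HybridForest::Utils.to_dataframe, then passed to the tree grower, whose result becomes the root of this Tree.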

# File 'lib/hybridforest/trees/tree.rb', line 16

def fit(instances)
  instances = HybridForest::Utils.to_dataframe(instances)
  @root = @tree_grower.grow_tree(instances)
  self
end
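Because fit returns self, fitting and predicting can be chained in one expression. The sketch below reproduces only that convention in a hypothetical SketchTree class (not HybridForest code); lambdas stand in for tree_grower.grow_tree and for the root node's classify, and the threshold rule is invented for illustration.

```ruby
# Minimal sketch of the "fit returns self" convention; not HybridForest code.
class SketchTree
  def initialize(grower)
    @grower = grower # hypothetical lambda standing in for tree_grower
  end

  def fit(instances)
    @root = @grower.call(instances) # stand-in for @tree_grower.grow_tree
    self                            # returning self enables fit(...).predict(...)
  end

  def predict(instances)
    instances.map { |instance| @root.call(instance) } # root lambda ~ classify
  end
end

# Hypothetical grower: labels a row "big" when x exceeds the training mean of x.
threshold_grower = lambda do |rows|
  mean = rows.sum { |r| r[:x] }.fdiv(rows.size)
  ->(row) { row[:x] > mean ? "big" : "small" }
end

train = [{ x: 1 }, { x: 2 }, { x: 9 }] # mean of x is 4.0
predictions = SketchTree.new(threshold_grower).fit(train).predict([{ x: 0 }, { x: 10 }])
p predictions # prints ["small", "big"]
```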

#inspect ⇒ Object

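Returns a string representation of this Tree: the root node's print_string once the tree has been fitted, or "Empty tree: " followed by the default Object#inspect output if it has not.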

# File 'lib/hybridforest/trees/tree.rb', line 37

def inspect
  if @root.nil?
    "Empty tree: #{super}"
  else
    @root.print_string
  end
end
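The two branches of inspect can be seen in isolation with the sketch below. InspectSketch and StubNode are hypothetical stand-ins, not HybridForest classes; StubNode only exposes print_string, which is all the fitted branch calls. In the empty branch, super resolves to Object#inspect, which includes the class name, object address and instance variables.

```ruby
# Hypothetical root node exposing only #print_string (a Struct member).
StubNode = Struct.new(:print_string)

# Sketch of Tree#inspect's branching; not HybridForest code.
class InspectSketch
  def initialize(root = nil)
    @root = root
  end

  def inspect
    if @root.nil?
      "Empty tree: #{super}" # super is Object#inspect for this object
    else
      @root.print_string     # fitted: delegate to the root node
    end
  end
end

puts InspectSketch.new.inspect # Empty tree: #<InspectSketch:0x... @root=nil>
puts InspectSketch.new(StubNode.new("x <= 4.0 ? small : big")).inspect
```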

#predict(instances) ⇒ Object

Predicts a label for each instance in the dataset instances and returns an array of labels. Raises Errors::InvalidStateError if called before #fit.

# File 'lib/hybridforest/trees/tree.rb', line 25

def predict(instances)
  if @root.nil?
    raise Errors::InvalidStateError,
      "You must call #fit before you call #predict"
  end

  HybridForest::Utils.to_dataframe(instances).each_row.reduce([]) do |predictions, instance|
    predictions << @root.classify(instance)
  end
end
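The predict method does two things: it refuses to run before fit, then classifies each row in turn. The sketch below mirrors both steps with hypothetical stand-ins (a local InvalidStateError and a lambda in place of the fitted root). It uses map, which gives the same result as the reduce([]) { |predictions, instance| predictions << ... } form above, because Array#<< returns the array it appends to.

```ruby
# Hypothetical stand-in for HybridForest's Errors::InvalidStateError.
class InvalidStateError < StandardError; end

# Sketch of Tree#predict's guard and per-row classification; not HybridForest code.
class PredictSketch
  attr_writer :root # lambda standing in for a fitted root node's #classify

  def predict(rows)
    raise InvalidStateError, "You must call #fit before you call #predict" if @root.nil?

    rows.map { |row| @root.call(row) } # one label per row, in input order
  end
end

model = PredictSketch.new
begin
  model.predict([{ x: 1 }])
rescue InvalidStateError => e
  puts "Error: #{e.message}"
end

model.root = ->(row) { row[:x].odd? ? "odd" : "even" }
p model.predict([{ x: 1 }, { x: 2 }]) # prints ["odd", "even"]
```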