Class: Rumale::Tree::BaseDecisionTree

Inherits:
Object
Includes:
Base::BaseEstimator
Defined in:
lib/rumale/tree/base_decision_tree.rb

Overview

BaseDecisionTree is an abstract class for implementing decision tree-based estimators. This class is used internally.

Instance Attribute Summary

Attributes included from Base::BaseEstimator

#params

Instance Method Summary

Constructor Details

#initialize(criterion: nil, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ BaseDecisionTree

Initialize a decision tree-based estimator.



# File 'lib/rumale/tree/base_decision_tree.rb', line 26

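def initialize(criterion: nil, max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil)
  # Store the hyperparameters in the hash exposed as #params (from Base::BaseEstimator).
  @params = {}
  @params[:criterion] = criterion
  @params[:max_depth] = max_depth
  @params[:max_leaf_nodes] = max_leaf_nodes
  @params[:min_samples_leaf] = min_samples_leaf
  @params[:max_features] = max_features
  @params[:random_seed] = random_seed
  # Without an explicit seed, fall back to the value returned by Kernel#srand.
  @params[:random_seed] ||= srand
  # Model state starts out empty (nil).
  @tree = nil
  @feature_importances = nil
  @n_leaves = nil
  # Random number generator seeded from random_seed.
  @rng = Random.new(@params[:random_seed])
end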
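The seed handling in the constructor can be illustrated with a small plain-Ruby sketch. `ToyTreeParams` is a hypothetical class, not part of Rumale; it only mirrors the pattern above: hyperparameters go into a hash, a missing `random_seed` falls back to the value returned by `Kernel#srand`, and a `Random` instance is built from that seed. As a result, two instances created with the same seed produce the same random draws.

```ruby
# Hypothetical stand-in for the constructor's parameter handling (not Rumale code).
class ToyTreeParams
  attr_reader :params, :rng

  def initialize(max_depth: nil, min_samples_leaf: 1, random_seed: nil)
    # Keep hyperparameters in a single hash, as BaseDecisionTree does.
    @params = { max_depth: max_depth, min_samples_leaf: min_samples_leaf, random_seed: random_seed }
    # Kernel#srand (no argument) reseeds the global generator and returns the previous seed.
    @params[:random_seed] ||= srand
    # Per-instance generator, so results are reproducible for a given seed.
    @rng = Random.new(@params[:random_seed])
  end
end

a = ToyTreeParams.new(random_seed: 42)
b = ToyTreeParams.new(random_seed: 42)
puts a.rng.rand(100) == b.rng.rand(100) # same seed, same draw

c = ToyTreeParams.new
puts c.params[:random_seed].is_a?(Integer) # a seed is always recorded
```

Recording the fallback seed in `params` means a fitted model's randomness can be reproduced later by passing that stored value back in as `random_seed`.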

Instance Method Details

#apply(x) ⇒ Numo::Int32

Return the index of the leaf that each sample reached.



# File 'lib/rumale/tree/base_decision_tree.rb', line 45

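def apply(x)
  # Validate the input sample array.
  check_sample_array(x)
  # Route each row of x from the root to a leaf and collect the leaf indices.
  Numo::Int32[*(Array.new(x.shape[0]) { |n| apply_at_node(@tree, x[n, true]) })]
end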
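The per-sample traversal that `#apply` delegates to `apply_at_node` is not shown on this page. The following is a minimal sketch of what such a root-to-leaf walk could look like, using plain Ruby arrays instead of Numo. `ToyNode`, `toy_apply_at_node`, and `toy_apply` are hypothetical names, and the split rule (go left when the feature value is at most the threshold) is an assumption, not Rumale's confirmed behavior.

```ruby
# Hypothetical node structure; Rumale's actual node class is not documented here.
ToyNode = Struct.new(:leaf, :leaf_id, :feature_id, :threshold, :left, :right, keyword_init: true)

# Walk from the given node down to a leaf and return that leaf's index.
def toy_apply_at_node(node, sample)
  return node.leaf_id if node.leaf
  # Assumed split rule: values <= threshold go left, larger values go right.
  if sample[node.feature_id] <= node.threshold
    toy_apply_at_node(node.left, sample)
  else
    toy_apply_at_node(node.right, sample)
  end
end

# Apply the traversal to every row, mirroring the shape of #apply.
def toy_apply(tree, x)
  x.map { |row| toy_apply_at_node(tree, row) }
end

left  = ToyNode.new(leaf: true, leaf_id: 0)
right = ToyNode.new(leaf: true, leaf_id: 1)
root  = ToyNode.new(leaf: false, feature_id: 0, threshold: 0.5, left: left, right: right)

p toy_apply(root, [[0.2, 9.0], [0.8, 1.0], [0.5, 3.0]]) # => [0, 1, 0]
```

Only the feature chosen at each internal node is inspected, so the cost per sample is proportional to the depth of the path it follows, not to the number of features.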