Class: SVMKit::Ensemble::RandomForestClassifier

Inherits:
Object
  • Object
show all
Includes:
Base::BaseEstimator, Base::Classifier
Defined in:
lib/svmkit/ensemble/random_forest_classifier.rb

Overview

RandomForestClassifier is a class that implements random forest for classification.

Examples:

estimator =
  SVMKit::Ensemble::RandomForestClassifier.new(
    n_estimators: 10, criterion: 'gini', max_depth: 3, max_leaf_nodes: 10, min_samples_leaf: 5, random_seed: 1)
estimator.fit(training_samples, training_labels)
results = estimator.predict(testing_samples)

Instance Attribute Summary collapse

Attributes included from Base::BaseEstimator

#params

Instance Method Summary collapse

Methods included from Base::Classifier

#score

Constructor Details

#initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ RandomForestClassifier

Create a new classifier with random forest.



51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 51

# Create a new classifier with random forest.
#
# @param n_estimators [Integer] The number of decision trees in the forest.
# @param criterion [String] The split criterion passed to each tree (e.g. 'gini').
# @param max_depth [Integer, nil] The maximum depth passed to each tree.
# @param max_leaf_nodes [Integer, nil] The maximum number of leaves passed to each tree.
# @param min_samples_leaf [Integer] The minimum number of samples per leaf passed to each tree.
# @param max_features [Integer, nil] The number of features considered per split; resolved in #fit.
# @param random_seed [Integer, nil] Seed for the internal generator. When nil, the value returned
#   by Kernel#srand is used (note: srand also reseeds Ruby's global generator).
def initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
               max_features: nil, random_seed: nil)
  @params = {
    n_estimators: n_estimators,
    criterion: criterion,
    max_depth: max_depth,
    max_leaf_nodes: max_leaf_nodes,
    min_samples_leaf: min_samples_leaf,
    max_features: max_features,
    random_seed: random_seed || srand
  }
  @rng = Random.new(@params[:random_seed])
  # Fitted state is populated by #fit.
  @estimators = nil
  @classes = nil
  @feature_importances = nil
end

Instance Attribute Details

#classes ⇒ Numo::Int32 (readonly)

Return the class labels.



28
29
30
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 28

# Return the class labels.
#
# Set by #fit to the sorted unique labels of the training targets; nil before fitting.
# @return [Numo::Int32]
def classes
  @classes
end

#estimators ⇒ Array<DecisionTreeClassifier> (readonly)

Return the set of estimators.



24
25
26
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 24

# Return the set of estimators.
#
# Populated by #fit with params[:n_estimators] decision trees; nil before fitting.
# @return [Array<DecisionTreeClassifier>]
def estimators
  @estimators
end

#feature_importances ⇒ Numo::DFloat (readonly)

Return the importance for each feature.



32
33
34
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 32

# Return the importance for each feature.
#
# Computed in #fit as the sum of the trees' importances divided by its total; nil before fitting.
# @return [Numo::DFloat]
def feature_importances
  @feature_importances
end

#rng ⇒ Random (readonly)

Return the random generator used for random sampling (e.g. drawing bootstrap samples) when fitting the forest.



36
37
38
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 36

# Return the random generator used when fitting the forest.
#
# Seeded with params[:random_seed] in the constructor; #fit draws bootstrap sample indices from it.
# @return [Random]
def rng
  @rng
end

Instance Method Details

#apply(x) ⇒ Numo::Int32

Return the index of the leaf that each sample reached in each tree of the forest.



138
139
140
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 138

# Return the index of the leaf that each sample reached in every tree.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to assign to leaves.
# @return [Numo::Int32] One row per sample and one column per tree, assuming each tree's
#   #apply returns one leaf index per sample.
def apply(x)
  # One entry per tree, each holding the leaf ids of all samples.
  leaf_ids_per_tree = @params[:n_estimators].times.map { |tree_idx| @estimators[tree_idx].apply(x) }
  Numo::Int32[*leaf_ids_per_tree].transpose
end

#fit(x, y) ⇒ RandomForestClassifier

Fit the model with given training data.



73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 73

# Fit the model with given training data.
#
# Each tree is trained on a bootstrap sample (indices drawn with replacement using #rng)
# and receives its own seed drawn from #rng, so the trees do not share one random stream.
#
# If params[:max_features] is not an Integer, Math.sqrt(n_features).to_i features are
# considered per split. Any value is then clamped to the valid range [1, n_features].
# Note: the resolved value is written back into params[:max_features].
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data.
# @param y [Numo::Int32] (shape: [n_samples]) The labels of the training data.
# @return [RandomForestClassifier] The learned classifier itself.
def fit(x, y)
  # Initialize some variables.
  n_samples, n_features = x.shape
  # Previously a user-supplied max_features was silently capped at sqrt(n_features);
  # sqrt is now only the default, and explicit values are honored within [1, n_features].
  @params[:max_features] = Math.sqrt(n_features).to_i unless @params[:max_features].is_a?(Integer)
  @params[:max_features] = [[1, @params[:max_features]].max, n_features].min
  @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
  # Construct forest.
  @estimators = Array.new(@params[:n_estimators]) do
    # Give each tree a distinct seed (kept within the signed 32-bit range). With a shared seed,
    # every tree would presumably pick the same candidate-feature subsets, defeating the
    # decorrelation random forests rely on.
    tree_seed = @rng.rand(2**31 - 1)
    tree = Tree::DecisionTreeClassifier.new(
      criterion: @params[:criterion], max_depth: @params[:max_depth],
      max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
      max_features: @params[:max_features], random_seed: tree_seed
    )
    bootstrap_ids = Array.new(n_samples) { @rng.rand(0...n_samples) }
    tree.fit(x[bootstrap_ids, true], y[bootstrap_ids])
    tree
  end
  # Calculate feature importances as the normalized sum over all trees.
  @feature_importances = Numo::DFloat.zeros(n_features)
  @estimators.each { |tree| @feature_importances += tree.feature_importances }
  # Skip normalization when the total is zero (e.g. no tree made any split) to avoid 0/0 = NaN.
  total_importance = @feature_importances.sum
  @feature_importances /= total_importance if total_importance > 0
  self
end

#marshal_dump ⇒ Hash

Dump marshal data.



144
145
146
147
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 144

# Dump marshal data.
#
# @return [Hash] The parameters and fitted state needed to restore the classifier with #marshal_load.
def marshal_dump
  {
    params: @params,
    estimators: @estimators,
    classes: @classes,
    feature_importances: @feature_importances,
    rng: @rng
  }
end

#marshal_load(obj) ⇒ nil

Load marshal data.



151
152
153
154
155
156
157
158
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 151

# Load marshal data.
#
# @param obj [Hash] The hash produced by #marshal_dump.
# @return [nil]
def marshal_load(obj)
  @params, @estimators, @classes, @feature_importances, @rng =
    obj.values_at(:params, :estimators, :classes, :feature_importances, :rng)
  nil
end

#predict(x) ⇒ Numo::Int32

Predict class labels for samples.



100
101
102
103
104
105
106
107
108
109
110
111
112
113
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 100

# Predict class labels for samples by majority vote over the trees.
#
# Labels predicted by a tree that are not among #classes are ignored. Ties go to the
# class whose vote column max_index reports.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict labels for.
# @return [Numo::Int32] (shape: [n_samples]) The predicted class label per sample.
def predict(x)
  n_samples = x.shape.first
  label_list = @classes.to_a
  votes = Numo::DFloat.zeros(n_samples, @classes.size)
  @estimators.each do |tree|
    tree_labels = tree.predict(x)
    (0...n_samples).each do |sample_idx|
      column = label_list.index(tree_labels[sample_idx])
      next if column.nil?

      votes[sample_idx, column] += 1.0
    end
  end
  winners = (0...n_samples).map { |sample_idx| @classes[votes[sample_idx, true].max_index] }
  Numo::Int32[*winners]
end

#predict_proba(x) ⇒ Numo::DFloat

Predict class probabilities for samples.



119
120
121
122
123
124
125
126
127
128
129
130
131
132
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 119

# Predict class probabilities for samples.
#
# Each tree's probability columns are added to the column of the matching label in #classes.
# Labels unknown to the forest are ignored. Each row is then divided by its total.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict probabilities for.
# @return [Numo::DFloat] (shape: [n_samples, n_classes]) The predicted probability of each class per sample.
def predict_proba(x)
  n_samples = x.shape.first
  label_list = @classes.to_a
  prob_sums = Numo::DFloat.zeros(n_samples, @classes.size)
  @estimators.each do |tree|
    tree_probs = tree.predict_proba(x)
    (0...tree.classes.size).each do |tree_col|
      target_col = label_list.index(tree.classes[tree_col])
      next if target_col.nil?

      prob_sums[true, target_col] += tree_probs[true, tree_col]
    end
  end
  # Divide each row by its total via a transpose round-trip (column-wise broadcasting).
  row_totals = prob_sums.sum(axis: 1)
  (prob_sums.transpose / row_totals).transpose
end