Class: SVMKit::Ensemble::RandomForestClassifier
- Inherits: Object
  - Object
  - SVMKit::Ensemble::RandomForestClassifier
- Includes: Base::BaseEstimator, Base::Classifier
- Defined in: lib/svmkit/ensemble/random_forest_classifier.rb
Overview
RandomForestClassifier is a class that implements random forest for classification.
Instance Attribute Summary
- #classes ⇒ Numo::Int32 (readonly): Return the class labels.
- #estimators ⇒ Array<DecisionTreeClassifier> (readonly): Return the set of estimators.
- #feature_importances ⇒ Numo::DFloat (readonly): Return the importance for each feature.
- #rng ⇒ Random (readonly): Return the random generator used for bootstrap sampling.
Attributes included from Base::BaseEstimator
Instance Method Summary
- #apply(x) ⇒ Numo::Int32: Return the index of the leaf that each sample reached in each tree.
- #fit(x, y) ⇒ RandomForestClassifier: Fit the model with the given training data.
- #initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ RandomForestClassifier (constructor): Create a new random forest classifier.
- #marshal_dump ⇒ Hash: Dump marshal data.
- #marshal_load(obj) ⇒ nil: Load marshal data.
- #predict(x) ⇒ Numo::Int32: Predict class labels for samples.
- #predict_proba(x) ⇒ Numo::DFloat: Predict class probabilities for samples.
Methods included from Base::Classifier
Constructor Details
#initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ RandomForestClassifier
Create a new random forest classifier.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 51

def initialize(n_estimators: 10, criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1,
               max_features: nil, random_seed: nil)
  SVMKit::Validation.check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                                                       max_features: max_features, random_seed: random_seed)
  SVMKit::Validation.check_params_integer(n_estimators: n_estimators, min_samples_leaf: min_samples_leaf)
  SVMKit::Validation.check_params_string(criterion: criterion)
  SVMKit::Validation.check_params_positive(n_estimators: n_estimators, max_depth: max_depth,
                                           max_leaf_nodes: max_leaf_nodes, min_samples_leaf: min_samples_leaf,
                                           max_features: max_features)
  @params = {}
  @params[:n_estimators] = n_estimators
  @params[:criterion] = criterion
  @params[:max_depth] = max_depth
  @params[:max_leaf_nodes] = max_leaf_nodes
  @params[:min_samples_leaf] = min_samples_leaf
  @params[:max_features] = max_features
  @params[:random_seed] = random_seed
  @params[:random_seed] ||= srand
  @estimators = nil
  @classes = nil
  @feature_importances = nil
  @rng = Random.new(@params[:random_seed])
end
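The constructor builds its generator with Random.new(@params[:random_seed]), so two classifiers created with the same random_seed draw the same sequence of random indices. The following is a minimal sketch of that property using only Ruby's Random class; draw_indices is a hypothetical helper and is not part of SVMKit.

```ruby
# Hypothetical helper (not part of SVMKit): draw n indices in 0...n
# from a generator seeded the same way the constructor seeds @rng.
def draw_indices(seed, n)
  rng = Random.new(seed)
  Array.new(n) { rng.rand(0...n) }
end

first_run  = draw_indices(42, 5)
second_run = draw_indices(42, 5)

# Identical seeds produce identical draws.
same_draws = (first_run == second_run)
```

When random_seed is nil, the constructor falls back to the value returned by Kernel#srand, so the seed that was actually used stays available in the stored parameters.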
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 28

def classes
  @classes
end
#estimators ⇒ Array<DecisionTreeClassifier> (readonly)
Return the set of estimators.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 24

def estimators
  @estimators
end
#feature_importances ⇒ Numo::DFloat (readonly)
Return the importance for each feature.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 32

def feature_importances
  @feature_importances
end
#rng ⇒ Random (readonly)
Return the random generator used for bootstrap sampling.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 36

def rng
  @rng
end
Instance Method Details
#apply(x) ⇒ Numo::Int32
Return the index of the leaf that each sample reached in each tree.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 150

def apply(x)
  SVMKit::Validation.check_sample_array(x)
  Numo::Int32[*Array.new(@params[:n_estimators]) { |n| @estimators[n].apply(x) }].transpose
end
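The transpose in apply turns one row of leaf indices per tree into one row per sample. A small sketch with plain Ruby arrays standing in for Numo::Int32; the leaf indices are made up for illustration.

```ruby
# Leaf indices for 3 samples, one inner array per tree (illustrative values).
leaves_per_tree = [
  [2, 5, 5], # tree 0
  [1, 1, 4]  # tree 1
]

# After transposing, each row belongs to one sample and each column to one tree.
leaves_per_sample = leaves_per_tree.transpose
```

Here leaves_per_sample has shape [n_samples, n_estimators], matching the layout that apply returns.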
#fit(x, y) ⇒ RandomForestClassifier
Fit the model with the given training data.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 80

def fit(x, y)
  SVMKit::Validation.check_sample_array(x)
  SVMKit::Validation.check_label_array(y)
  SVMKit::Validation.check_sample_label_size(x, y)
  # Initialize some variables.
  n_samples, n_features = x.shape
  @params[:max_features] = n_features unless @params[:max_features].is_a?(Integer)
  @params[:max_features] = [[1, @params[:max_features]].max, Math.sqrt(n_features).to_i].min
  @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
  # Construct forest.
  @estimators = Array.new(@params[:n_estimators]) do |_n|
    tree = Tree::DecisionTreeClassifier.new(
      criterion: @params[:criterion], max_depth: @params[:max_depth],
      max_leaf_nodes: @params[:max_leaf_nodes], min_samples_leaf: @params[:min_samples_leaf],
      max_features: @params[:max_features], random_seed: @params[:random_seed]
    )
    bootstrap_ids = Array.new(n_samples) { @rng.rand(0...n_samples) }
    tree.fit(x[bootstrap_ids, true], y[bootstrap_ids])
  end
  # Calculate feature importances.
  @feature_importances = Numo::DFloat.zeros(n_features)
  @estimators.each { |tree| @feature_importances += tree.feature_importances }
  @feature_importances /= @feature_importances.sum
  self
end
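The three stages of fit (resolving max_features, bootstrap sampling, and normalizing feature importances) can be sketched with plain Ruby arrays in place of Numo arrays. resolve_max_features is a hypothetical helper that copies the two max_features lines from the listing, and all data values are made up for illustration.

```ruby
# Hypothetical helper mirroring the two max_features lines in fit.
def resolve_max_features(max_features, n_features)
  max_features = n_features unless max_features.is_a?(Integer)
  [[1, max_features].max, Math.sqrt(n_features).to_i].min
end

resolved_default = resolve_max_features(nil, 16) # falls back to 16, then capped at sqrt(16)
resolved_small   = resolve_max_features(2, 16)   # already below the cap, so it is kept

# Bootstrap: draw n_samples row indices with replacement.
rng = Random.new(1)
x = [[0.1, 1.0], [0.2, 2.0], [0.3, 3.0], [0.4, 4.0]]
y = [0, 0, 1, 1]
n_samples = x.size
bootstrap_ids = Array.new(n_samples) { rng.rand(0...n_samples) }
x_boot = x.values_at(*bootstrap_ids)
y_boot = y.values_at(*bootstrap_ids)

# Sum per-tree importances (illustrative values) and normalize so they sum to 1.
tree_importances = [[0.6, 0.4], [0.2, 0.8]]
summed = tree_importances.transpose.map(&:sum)
feature_importances = summed.map { |v| v / summed.sum }
```

As written in the listing, the resolved max_features never exceeds the integer part of the square root of the number of features, whatever value was passed to the constructor.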
#marshal_dump ⇒ Hash
Dump marshal data.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 157

def marshal_dump
  { params: @params, estimators: @estimators, classes: @classes,
    feature_importances: @feature_importances, rng: @rng }
end
#marshal_load(obj) ⇒ nil
Load marshal data.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 164

def marshal_load(obj)
  @params = obj[:params]
  @estimators = obj[:estimators]
  @classes = obj[:classes]
  @feature_importances = obj[:feature_importances]
  @rng = obj[:rng]
  nil
end
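Ruby's Marshal calls marshal_dump when serializing an object and marshal_load when restoring it, so these two methods are what let a fitted model survive a Marshal.dump / Marshal.load round trip. The sketch below uses a hypothetical ToyModel class that follows the same pattern with only two fields; it is a stand-in, not SVMKit code.

```ruby
# Hypothetical stand-in that follows the same marshal_dump/marshal_load pattern.
class ToyModel
  attr_reader :params, :rng

  def initialize(seed)
    @params = { random_seed: seed }
    @rng = Random.new(seed)
  end

  # Collect the state to serialize into a Hash.
  def marshal_dump
    { params: @params, rng: @rng }
  end

  # Restore the state from the Hash produced by marshal_dump.
  def marshal_load(obj)
    @params = obj[:params]
    @rng = obj[:rng]
    nil
  end
end

model = ToyModel.new(7)
restored = Marshal.load(Marshal.dump(model))
```

Because the generator is part of the dumped state, the restored object continues the random sequence from the point at which it was dumped.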
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 110

def predict(x)
  SVMKit::Validation.check_sample_array(x)
  n_samples, = x.shape
  n_classes = @classes.size
  classes_arr = @classes.to_a
  ballot_box = Numo::DFloat.zeros(n_samples, n_classes)
  @estimators.each do |tree|
    predicted = tree.predict(x)
    n_samples.times do |n|
      class_id = classes_arr.index(predicted[n])
      ballot_box[n, class_id] += 1.0 unless class_id.nil?
    end
  end
  Numo::Int32[*Array.new(n_samples) { |n| @classes[ballot_box[n, true].max_index] }]
end
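predict is a majority vote: each tree casts one vote per sample, and the class with the most votes wins. A sketch of the same ballot-box logic with plain Ruby arrays instead of Numo arrays; the tree predictions are made up for illustration.

```ruby
classes = [0, 1, 2]
# Predicted labels from three trees for two samples (illustrative values).
tree_predictions = [[0, 2], [1, 2], [1, 0]]
n_samples = 2

# One row of vote counts per sample, one column per class.
ballot_box = Array.new(n_samples) { Array.new(classes.size, 0.0) }
tree_predictions.each do |predicted|
  predicted.each_with_index do |label, n|
    class_id = classes.index(label)
    ballot_box[n][class_id] += 1.0 unless class_id.nil?
  end
end

# Pick the class with the most votes for each sample.
labels = ballot_box.map { |votes| classes[votes.index(votes.max)] }
```

In this sketch the first sample receives two votes for class 1 and the second receives two votes for class 2, so labels is [1, 2].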
#predict_proba(x) ⇒ Numo::DFloat
Predict class probabilities for samples.
# File 'lib/svmkit/ensemble/random_forest_classifier.rb', line 130

def predict_proba(x)
  SVMKit::Validation.check_sample_array(x)
  n_samples, = x.shape
  n_classes = @classes.size
  classes_arr = @classes.to_a
  ballot_box = Numo::DFloat.zeros(n_samples, n_classes)
  @estimators.each do |tree|
    probs = tree.predict_proba(x)
    tree.classes.size.times do |n|
      class_id = classes_arr.index(tree.classes[n])
      ballot_box[true, class_id] += probs[true, n] unless class_id.nil?
    end
  end
  (ballot_box.transpose / ballot_box.sum(axis: 1)).transpose
end
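predict_proba sums the per-tree probabilities column by column, mapping each tree's own class list onto the forest's class list, and then normalizes every row. The mapping matters because a tree trained on a bootstrap sample may not have seen every class. A sketch with plain Ruby arrays and hashes in place of Numo arrays and tree objects; the probabilities are made up for illustration.

```ruby
classes = [0, 1, 2]
# Each entry lists the classes one tree saw and its per-sample probabilities.
# The second tree never saw class 1 (illustrative values).
trees = [
  { classes: [0, 1, 2], probs: [[0.5, 0.5, 0.0]] },
  { classes: [0, 2],    probs: [[0.5, 0.5]] }
]
n_samples = 1

ballot_box = Array.new(n_samples) { Array.new(classes.size, 0.0) }
trees.each do |tree|
  tree[:classes].each_with_index do |label, col|
    class_id = classes.index(label)
    next if class_id.nil?

    # Add this tree's probability column to the matching forest column.
    n_samples.times { |n| ballot_box[n][class_id] += tree[:probs][n][col] }
  end
end

# Normalize each row so the probabilities sum to 1.
proba = ballot_box.map { |row| row.map { |v| v / row.sum } }
```

Classes that a tree did not see simply receive no contribution from that tree, and the final row normalization keeps each sample's probabilities summing to 1.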