Class: SVMKit::Tree::DecisionTreeClassifier
- Inherits:
- Object
- Inheritance chain: Object > SVMKit::Tree::DecisionTreeClassifier
- Includes:
- Base::BaseEstimator, Base::Classifier
- Defined in:
- lib/svmkit/tree/decision_tree_classifier.rb
Overview
DecisionTreeClassifier is a class that implements a decision tree for classification.
Instance Attribute Summary
- #classes ⇒ Numo::Int32 (readonly)
  Return the class labels.
- #feature_importances ⇒ Numo::DFloat (readonly)
  Return the importance for each feature.
- #leaf_labels ⇒ Numo::Int32 (readonly)
  Return the labels assigned to each leaf.
- #rng ⇒ Random (readonly)
  Return the random generator used for random sampling.
- #tree ⇒ Node (readonly)
  Return the learned tree.
Attributes included from Base::BaseEstimator
Instance Method Summary
- #apply(x) ⇒ Numo::Int32
  Return the index of the leaf that each sample reached.
- #fit(x, y) ⇒ DecisionTreeClassifier
  Fit the model with given training data.
- #initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ DecisionTreeClassifier (constructor)
  Create a new classifier with decision tree algorithm.
- #marshal_dump ⇒ Hash
  Dump marshal data.
- #marshal_load(obj) ⇒ nil
  Load marshal data.
- #predict(x) ⇒ Numo::Int32
  Predict class labels for samples.
- #predict_proba(x) ⇒ Numo::DFloat
  Predict probability for samples.
Methods included from Base::Classifier
Constructor Details
#initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ DecisionTreeClassifier
Create a new classifier with decision tree algorithm.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 119
    def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
                   random_seed: nil)
      SVMKit::Validation.check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                                                           max_features: max_features, random_seed: random_seed)
      SVMKit::Validation.check_params_integer(min_samples_leaf: min_samples_leaf)
      SVMKit::Validation.check_params_string(criterion: criterion)
      SVMKit::Validation.check_params_positive(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                                               min_samples_leaf: min_samples_leaf, max_features: max_features)
      @params = {}
      @params[:criterion] = criterion
      @params[:max_depth] = max_depth
      @params[:max_leaf_nodes] = max_leaf_nodes
      @params[:min_samples_leaf] = min_samples_leaf
      @params[:max_features] = max_features
      @params[:random_seed] = random_seed
      @params[:random_seed] ||= srand
      @criterion = :gini
      @criterion = :entropy if @params[:criterion] == 'entropy'
      @tree = nil
      @classes = nil
      @feature_importances = nil
      @n_leaves = nil
      @leaf_labels = nil
      @rng = Random.new(@params[:random_seed])
    end
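The constructor maps the `criterion` string to either `:gini` or `:entropy`, and any value other than 'entropy' falls back to `:gini`. As a rough illustration of what these two impurity measures usually compute, here is a minimal sketch using only the Ruby standard library. The helper names `gini` and `entropy` and the base-2 logarithm are assumptions made for this example; this section does not show how SVMKit evaluates them internally.

```ruby
# Illustrative helpers only; not SVMKit's internal code.

# Gini impurity: 1 - sum(p_k^2) over the class proportions p_k.
def gini(labels)
  n = labels.size.to_f
  1.0 - labels.tally.values.sum { |count| (count / n)**2 }
end

# Shannon entropy: -sum(p_k * log2(p_k)). The base-2 logarithm is an assumption.
def entropy(labels)
  n = labels.size.to_f
  labels.tally.values.sum do |count|
    prob = count / n
    -prob * Math.log2(prob)
  end
end

balanced = [0, 0, 1, 1]
pure = [1, 1, 1]
puts gini(balanced)
puts entropy(balanced)
puts gini(pure)
```

Both measures are zero for a node that holds a single class and grow as the classes become more evenly mixed.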
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 89
    def classes
      @classes
    end
#feature_importances ⇒ Numo::DFloat (readonly)
Return the importance for each feature.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 93
    def feature_importances
      @feature_importances
    end
#leaf_labels ⇒ Numo::Int32 (readonly)
Return the labels assigned to each leaf.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 105
    def leaf_labels
      @leaf_labels
    end
#rng ⇒ Random (readonly)
Return the random generator used for random sampling.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 101
    def rng
      @rng
    end
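The generator is created in the constructor as `Random.new(@params[:random_seed])`, so a fixed `random_seed` should give repeatable random draws. The sketch below shows that property with the standard library alone. Drawing a feature subset is an assumed use, chosen because the constructor also takes `max_features`; how the library actually consumes the generator is not shown in this section.

```ruby
# Two generators built from the same seed produce the same sequence.
seed = 42
rng_a = Random.new(seed)
rng_b = Random.new(seed)

features = (0...10).to_a

# Array#sample accepts a random: option, so a seeded generator
# makes the selection repeatable. (Assumed usage, for illustration.)
subset_a = features.sample(3, random: rng_a)
subset_b = features.sample(3, random: rng_b)

puts subset_a == subset_b
```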
#tree ⇒ Node (readonly)
Return the learned tree.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 97
    def tree
      @tree
    end
Instance Method Details
#apply(x) ⇒ Numo::Int32
Return the index of the leaf that each sample reached.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 186
    def apply(x)
      SVMKit::Validation.check_sample_array(x)
      Numo::Int32[*(Array.new(x.shape[0]) { |n| apply_at_node(@tree, x[n, true]) })]
    end
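`#apply` calls the private helper `apply_at_node` once per sample, and that helper's body is not listed here. The following is a minimal sketch of how such a traversal could look on a binary tree, using plain arrays instead of Numo. The `SketchNode` struct, its fields, and the "go left when the value is at most the threshold" rule are all hypothetical; SVMKit's actual `Node` class may differ.

```ruby
# Hypothetical node layout for illustration; not SVMKit's Node class.
SketchNode = Struct.new(:feature_id, :threshold, :left, :right, :leaf_id, keyword_init: true) do
  def leaf?
    !leaf_id.nil?
  end
end

# Walk from the given node down to a leaf and return that leaf's index.
def apply_at_node(node, sample)
  return node.leaf_id if node.leaf?

  branch = sample[node.feature_id] <= node.threshold ? node.left : node.right
  apply_at_node(branch, sample)
end

# Return the leaf index reached by each sample.
def apply(tree, samples)
  samples.map { |sample| apply_at_node(tree, sample) }
end

tree = SketchNode.new(
  feature_id: 0, threshold: 0.5,
  left: SketchNode.new(leaf_id: 0),
  right: SketchNode.new(
    feature_id: 1, threshold: 2.0,
    left: SketchNode.new(leaf_id: 1),
    right: SketchNode.new(leaf_id: 2)
  )
)

leaf_ids = apply(tree, [[0.1, 9.0], [0.9, 1.0], [0.9, 3.0]])
p leaf_ids
```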
#fit(x, y) ⇒ DecisionTreeClassifier
Fit the model with given training data.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 150
    def fit(x, y)
      SVMKit::Validation.check_sample_array(x)
      SVMKit::Validation.check_label_array(y)
      SVMKit::Validation.check_sample_label_size(x, y)
      n_samples, n_features = x.shape
      @params[:max_features] = n_features if @params[:max_features].nil?
      @params[:max_features] = [@params[:max_features], n_features].min
      @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
      build_tree(x, y)
      eval_importance(n_samples, n_features)
      self
    end
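Before building the tree, `#fit` resolves `max_features` (defaulting to the number of features and clamping larger values) and collects the sorted unique labels. These two steps can be mirrored with plain Ruby arrays, as in the sketch below. The helper names `resolve_max_features` and `class_labels` are made up for this example.

```ruby
# Default to all features when nil, and never exceed the feature count.
def resolve_max_features(max_features, n_features)
  max_features = n_features if max_features.nil?
  [max_features, n_features].min
end

# Sorted unique labels, as #fit derives them from y.
def class_labels(y)
  y.uniq.sort
end

p resolve_max_features(nil, 4)
p resolve_max_features(10, 4)
p class_labels([2, 0, 1, 0, 2])
```

Note that the listed `#fit` writes the resolved value back into `@params[:max_features]`, so the stored parameter reflects the clamped value after fitting.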
#marshal_dump ⇒ Hash
Dump marshal data.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 193
    def marshal_dump
      { params: @params,
        classes: @classes,
        criterion: @criterion,
        tree: @tree,
        feature_importances: @feature_importances,
        leaf_labels: @leaf_labels,
        rng: @rng }
    end
#marshal_load(obj) ⇒ nil
Load marshal data.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 205
    def marshal_load(obj)
      @params = obj[:params]
      @classes = obj[:classes]
      @criterion = obj[:criterion]
      @tree = obj[:tree]
      @feature_importances = obj[:feature_importances]
      @leaf_labels = obj[:leaf_labels]
      @rng = obj[:rng]
      nil
    end
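Ruby's `Marshal.dump` and `Marshal.load` call `marshal_dump` and `marshal_load` when an object defines them, which is how a fitted model can be saved and restored. The sketch below shows the round trip with a toy class that follows the same hash-based pattern. `ToyEstimator` is hypothetical and stands in for the classifier so that the example runs without the gem.

```ruby
# Hypothetical stand-in that uses the same marshal hooks as the classifier.
class ToyEstimator
  attr_reader :params, :classes

  def initialize(max_depth: nil)
    @params = { max_depth: max_depth }
    @classes = nil
  end

  def fit(labels)
    @classes = labels.uniq.sort
    self
  end

  # Called by Marshal.dump; returns the state to serialize.
  def marshal_dump
    { params: @params, classes: @classes }
  end

  # Called by Marshal.load with the object returned by marshal_dump.
  def marshal_load(obj)
    @params = obj[:params]
    @classes = obj[:classes]
    nil
  end
end

model = ToyEstimator.new(max_depth: 3).fit([1, 0, 1, 2])
restored = Marshal.load(Marshal.dump(model))
p restored.params
p restored.classes
```

With the real classifier, the same `Marshal.dump` / `Marshal.load` pair would presumably be applied to a fitted `DecisionTreeClassifier` instance.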
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 167
    def predict(x)
      SVMKit::Validation.check_sample_array(x)
      @leaf_labels[apply(x)]
    end
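`#predict` is a lookup: `apply(x)` yields one leaf index per sample, and indexing `@leaf_labels` with those indices returns the label stored for each leaf. A plain-array sketch of the same lookup, with made-up values, might look like this:

```ruby
# Label stored for leaf 0, 1 and 2 (made-up values).
leaf_labels = [2, 0, 1]

# Leaf reached by each of four samples, as #apply would report it.
leaf_ids = [0, 2, 2, 1]

# Array#values_at plays the role of Numo's index-array lookup here.
predictions = leaf_labels.values_at(*leaf_ids)
p predictions
```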
#predict_proba(x) ⇒ Numo::DFloat
Predict probability for samples.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 176
    def predict_proba(x)
      SVMKit::Validation.check_sample_array(x)
      probs = Numo::DFloat[*(Array.new(x.shape[0]) { |n| predict_at_node(@tree, x[n, true]) })]
      probs[true, @classes]
    end
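`#predict_proba` gathers one row of probabilities per sample from the private helper `predict_at_node`, whose body is not listed here. One common way to obtain such a row is to normalize the class counts recorded at the leaf a sample reaches. The sketch below assumes that approach; the helper `leaf_probabilities` and the count table are hypothetical.

```ruby
# Turn the class counts observed at a leaf into probabilities.
# (Assumed approach; predict_at_node itself is not shown in the source.)
def leaf_probabilities(class_counts)
  total = class_counts.sum.to_f
  class_counts.map { |count| count / total }
end

# Leaf index => number of training samples per class (made-up values).
leaf_counts = { 0 => [3, 1], 1 => [0, 4] }

# Leaf reached by each of two samples.
leaf_ids = [1, 0]

probs = leaf_ids.map { |id| leaf_probabilities(leaf_counts[id]) }
p probs
```

In the listed method, the final `probs[true, @classes]` selects columns using the class label values as indices.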