Class: SVMKit::Tree::DecisionTreeClassifier
- Inherits:
- Object
  - SVMKit::Tree::DecisionTreeClassifier
- Includes:
- Base::BaseEstimator, Base::Classifier
- Defined in:
- lib/svmkit/tree/decision_tree_classifier.rb
Overview
DecisionTreeClassifier is a class that implements a decision tree for classification.
Instance Attribute Summary
-
#classes ⇒ Numo::Int32
readonly
Return the class labels.
-
#feature_importances ⇒ Numo::DFloat
readonly
Return the importance for each feature.
-
#leaf_labels ⇒ Numo::Int32
readonly
Return the labels assigned to each leaf.
-
#rng ⇒ Random
readonly
Return the random number generator used to select feature indices at random.
-
#tree ⇒ Node
readonly
Return the learned tree.
Attributes included from Base::BaseEstimator
Instance Method Summary
-
#apply(x) ⇒ Numo::Int32
Return the index of the leaf that each sample reached.
-
#fit(x, y) ⇒ DecisionTreeClassifier
Fit the model with given training data.
-
#initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ DecisionTreeClassifier
constructor
Create a new classifier with decision tree algorithm.
-
#marshal_dump ⇒ Hash
Dump marshal data.
-
#marshal_load(obj) ⇒ nil
Load marshal data.
-
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
-
#predict_proba(x) ⇒ Numo::DFloat
Predict class probabilities for samples.
Methods included from Base::Classifier
Constructor Details
#initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ DecisionTreeClassifier
Create a new classifier with decision tree algorithm.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 56
    def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
                   random_seed: nil)
      SVMKit::Validation.check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                                                           max_features: max_features, random_seed: random_seed)
      SVMKit::Validation.check_params_integer(min_samples_leaf: min_samples_leaf)
      SVMKit::Validation.check_params_string(criterion: criterion)
      SVMKit::Validation.check_params_positive(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                                               min_samples_leaf: min_samples_leaf, max_features: max_features)
      @params = {}
      @params[:criterion] = criterion
      @params[:max_depth] = max_depth
      @params[:max_leaf_nodes] = max_leaf_nodes
      @params[:min_samples_leaf] = min_samples_leaf
      @params[:max_features] = max_features
      @params[:random_seed] = random_seed
      @params[:random_seed] ||= srand
      @criterion = :gini
      @criterion = :entropy if @params[:criterion] == 'entropy'
      @tree = nil
      @classes = nil
      @feature_importances = nil
      @n_leaves = nil
      @leaf_labels = nil
      @rng = Random.new(@params[:random_seed])
    end
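The constructor accepts 'gini' or 'entropy' as the criterion and falls back to Gini for any other string. The impurity formulas themselves are not shown on this page, so the following is only a plain-Ruby sketch of the textbook definitions. The helper names gini_impurity, entropy_impurity, and impurity are hypothetical and not part of SVMKit, and the log base the library uses for entropy is an assumption (base 2 here).

```ruby
# Hypothetical helpers (not SVMKit API): impurity of a non-empty list of labels.

# Gini impurity: 1 - sum(p_k^2), where p_k is the proportion of class k.
def gini_impurity(labels)
  n = labels.size.to_f
  counts = labels.group_by(&:itself).values.map(&:size)
  1.0 - counts.sum { |c| (c / n)**2 }
end

# Entropy in bits: -sum(p_k * log2(p_k)). Base 2 is an assumption.
def entropy_impurity(labels)
  n = labels.size.to_f
  counts = labels.group_by(&:itself).values.map(&:size)
  -counts.sum { |c| (c / n) * Math.log2(c / n) }
end

# Mirrors the constructor's fallback: only 'entropy' selects entropy,
# every other string selects Gini.
def impurity(labels, criterion: 'gini')
  criterion == 'entropy' ? entropy_impurity(labels) : gini_impurity(labels)
end

puts impurity([0, 0, 1, 1])                       # evenly mixed node, Gini
puts impurity([0, 0, 1, 1], criterion: 'entropy') # evenly mixed node, entropy
puts impurity([1, 1, 1])                          # pure node
```

Both measures are zero for a pure node and largest for an evenly mixed one, so either can rank candidate splits; they differ only in how strongly they penalize mixed nodes.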
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 26
    def classes
      @classes
    end
#feature_importances ⇒ Numo::DFloat (readonly)
Return the importance for each feature.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 30
    def feature_importances
      @feature_importances
    end
#leaf_labels ⇒ Numo::Int32 (readonly)
Return the labels assigned to each leaf.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 42
    def leaf_labels
      @leaf_labels
    end
#rng ⇒ Random (readonly)
Return the random number generator used to select feature indices at random.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 38
    def rng
      @rng
    end
#tree ⇒ Node (readonly)
Return the learned tree.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 34
    def tree
      @tree
    end
Instance Method Details
#apply(x) ⇒ Numo::Int32
Return the index of the leaf that each sample reached.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 123
    def apply(x)
      SVMKit::Validation.check_sample_array(x)
      Numo::Int32[*(Array.new(x.shape[0]) { |n| apply_at_node(@tree, x[n, true]) })]
    end
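#apply delegates to a private helper, apply_at_node, whose body is not shown on this page. The sketch below illustrates one plausible way such a traversal could work, using plain Ruby arrays instead of Numo arrays. SketchNode, its fields, leaf_index_for, apply_sketch, and the "go left when the value is at most the threshold" rule are all assumptions made for illustration, not the library's actual Node class or logic.

```ruby
# Hypothetical node type (not SVMKit::Tree::Node): a leaf carries leaf_id,
# an internal node carries a feature index, a threshold, and two children.
SketchNode = Struct.new(:leaf_id, :feature_id, :threshold, :left, :right, keyword_init: true) do
  def leaf?
    left.nil? && right.nil?
  end
end

# Walk from the given node down to a leaf and return that leaf's index.
# Assumed rule: go left when the feature value is <= the threshold.
def leaf_index_for(node, sample)
  return node.leaf_id if node.leaf?

  child = sample[node.feature_id] <= node.threshold ? node.left : node.right
  leaf_index_for(child, sample)
end

# One leaf index per sample, like #apply, but over arrays of arrays.
def apply_sketch(tree, samples)
  samples.map { |sample| leaf_index_for(tree, sample) }
end

# A small hand-built tree with three leaves.
SKETCH_TREE = SketchNode.new(
  feature_id: 0, threshold: 2.5,
  left: SketchNode.new(leaf_id: 0),
  right: SketchNode.new(
    feature_id: 1, threshold: 1.0,
    left: SketchNode.new(leaf_id: 1),
    right: SketchNode.new(leaf_id: 2)
  )
)

p apply_sketch(SKETCH_TREE, [[1.0, 5.0], [3.0, 0.5], [3.0, 2.0]]) # => [0, 1, 2]
```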
#fit(x, y) ⇒ DecisionTreeClassifier
Fit the model with given training data.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 87
    def fit(x, y)
      SVMKit::Validation.check_sample_array(x)
      SVMKit::Validation.check_label_array(y)
      SVMKit::Validation.check_sample_label_size(x, y)
      n_samples, n_features = x.shape
      @params[:max_features] = n_features if @params[:max_features].nil?
      @params[:max_features] = [@params[:max_features], n_features].min
      uniq_y = y.to_a.uniq.sort
      @classes = Numo::Int32.asarray(uniq_y)
      build_tree(x, y.map { |v| uniq_y.index(v) })
      eval_importance(n_samples, n_features)
      self
    end
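Two preprocessing steps in #fit are easy to miss: max_features is clamped to the number of features, and the labels are re-encoded as positions in the sorted list of unique labels before the tree is built. The following plain-Ruby sketch reproduces both steps on ordinary arrays. The helper names encode_labels and effective_max_features are hypothetical and do not exist in SVMKit.

```ruby
# Hypothetical helper: returns the sorted unique labels and, for each
# original label, its position in that sorted list.
def encode_labels(y)
  classes = y.uniq.sort
  [classes, y.map { |v| classes.index(v) }]
end

# Hypothetical helper: nil means "use every feature", and a value larger
# than the number of features is clamped down to it.
def effective_max_features(max_features, n_features)
  [max_features || n_features, n_features].min
end

classes, encoded = encode_labels([7, 3, 7, 9, 3])
p classes # => [3, 7, 9]
p encoded # => [1, 0, 1, 2, 0]

p effective_max_features(nil, 4) # => 4
p effective_max_features(10, 4)  # => 4
p effective_max_features(2, 4)   # => 2
```

Encoding labels as consecutive indices lets per-class counts be kept in a fixed-size array regardless of the original label values.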
#marshal_dump ⇒ Hash
Dump marshal data.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 130
    def marshal_dump
      { params: @params,
        classes: @classes,
        criterion: @criterion,
        tree: @tree,
        feature_importances: @feature_importances,
        leaf_labels: @leaf_labels,
        rng: @rng }
    end
#marshal_load(obj) ⇒ nil
Load marshal data.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 142
    def marshal_load(obj)
      @params = obj[:params]
      @classes = obj[:classes]
      @criterion = obj[:criterion]
      @tree = obj[:tree]
      @feature_importances = obj[:feature_importances]
      @leaf_labels = obj[:leaf_labels]
      @rng = obj[:rng]
      nil
    end
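Ruby's Marshal calls marshal_dump when serializing an object and marshal_load on a freshly allocated instance when deserializing it, which is how a fitted model can be saved and restored. The sketch below shows that round trip with a small stand-in class. SketchModel and round_trip are hypothetical names used for illustration only, and the class keeps just two of the fields the real classifier dumps.

```ruby
# Hypothetical stand-in for a fitted model (not an SVMKit class).
class SketchModel
  attr_reader :params, :leaf_labels

  def initialize(params, leaf_labels)
    @params = params
    @leaf_labels = leaf_labels
  end

  # Called by Marshal.dump: return the state worth saving.
  def marshal_dump
    { params: @params, leaf_labels: @leaf_labels }
  end

  # Called by Marshal.load on a new, uninitialized instance.
  def marshal_load(obj)
    @params = obj[:params]
    @leaf_labels = obj[:leaf_labels]
    nil
  end
end

# Serialize to a binary string and restore a separate copy.
def round_trip(model)
  Marshal.load(Marshal.dump(model))
end

restored = round_trip(SketchModel.new({ criterion: 'gini' }, [1, 0, 1]))
p restored.params      # => {:criterion=>"gini"} (Ruby 3.4+ prints {criterion: "gini"})
p restored.leaf_labels # => [1, 0, 1]
```

In practice the string returned by Marshal.dump would be written to a file in binary mode and read back later. Marshal.load should only be used on data from a trusted source.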
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 105
    def predict(x)
      SVMKit::Validation.check_sample_array(x)
      @leaf_labels[apply(x)]
    end
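#predict is a lookup: #apply yields one leaf index per sample, and indexing @leaf_labels with that array returns the label stored for each leaf. The sketch below shows the same lookup with plain Ruby arrays, where Array#values_at plays the role of Numo's index-array lookup. The helper name labels_for_leaves and the example values are hypothetical.

```ruby
# Hypothetical helper: leaf_labels[i] is the label assigned to leaf i,
# and leaf_ids holds the leaf reached by each sample.
def labels_for_leaves(leaf_labels, leaf_ids)
  leaf_labels.values_at(*leaf_ids)
end

leaf_labels = [10, 20, 30] # label stored for leaves 0, 1, and 2
leaf_ids = [2, 0, 0, 1]    # leaves reached by four samples

p labels_for_leaves(leaf_labels, leaf_ids) # => [30, 10, 10, 20]
```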
#predict_proba(x) ⇒ Numo::DFloat
Predict class probabilities for samples.
    # File 'lib/svmkit/tree/decision_tree_classifier.rb', line 114
    def predict_proba(x)
      SVMKit::Validation.check_sample_array(x)
      Numo::DFloat[*(Array.new(x.shape[0]) { |n| predict_at_node(@tree, x[n, true]) })]
    end
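The private helper predict_at_node is not shown on this page, so how each leaf's probabilities are derived is not confirmed here. A common convention for decision trees, assumed in the sketch below, is to use the proportion of training samples of each class that reached the leaf. The helper name class_proportions is hypothetical, and the labels are assumed to be encoded as indices from 0 to n_classes - 1.

```ruby
# Hypothetical helper: per-class proportions among the training labels
# that reached one leaf. Assumes a non-empty list of encoded labels.
def class_proportions(encoded_labels, n_classes)
  counts = Array.new(n_classes, 0)
  encoded_labels.each { |k| counts[k] += 1 }
  counts.map { |c| c.to_f / encoded_labels.size }
end

# Four training samples reached this leaf: two of class 0, one each of 1 and 2.
p class_proportions([0, 0, 1, 2], 3) # => [0.5, 0.25, 0.25]
```

Under this assumption each row returned by #predict_proba would sum to 1, with one column per entry in #classes.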