Class: SVMKit::Tree::DecisionTreeClassifier
- Inherits:
- Object
- SVMKit::Tree::DecisionTreeClassifier < Object
- Includes:
- Base::BaseEstimator, Base::Classifier
- Defined in:
- lib/svmkit/tree/decision_tree_classifier.rb
Overview
DecisionTreeClassifier implements a decision tree for classification.
Instance Attribute Summary
-
#classes ⇒ Numo::Int32
readonly
Return the class labels.
-
#feature_importances ⇒ Numo::DFloat
readonly
Return the importance for each feature.
-
#leaf_labels ⇒ Numo::Int32
readonly
Return the labels assigned to each leaf.
-
#rng ⇒ Random
readonly
Return the random generator used for random selection of feature indices.
-
#tree ⇒ Node
readonly
Return the learned tree.
Attributes included from Base::BaseEstimator
Instance Method Summary
-
#apply(x) ⇒ Numo::Int32
Return the index of the leaf that each sample reached.
-
#fit(x, y) ⇒ DecisionTreeClassifier
Fit the model with given training data.
-
#initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ DecisionTreeClassifier
constructor
Create a new classifier with decision tree algorithm.
-
#marshal_dump ⇒ Hash
Dump marshal data.
-
#marshal_load(obj) ⇒ nil
Load marshal data.
-
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
-
#predict_proba(x) ⇒ Numo::DFloat
Predict probability for samples.
Methods included from Base::Classifier
Constructor Details
#initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil, random_seed: nil) ⇒ DecisionTreeClassifier
Create a new classifier with decision tree algorithm.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 56
def initialize(criterion: 'gini', max_depth: nil, max_leaf_nodes: nil, min_samples_leaf: 1, max_features: nil,
               random_seed: nil)
  SVMKit::Validation.check_params_type_or_nil(Integer, max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                                                       max_features: max_features, random_seed: random_seed)
  SVMKit::Validation.check_params_integer(min_samples_leaf: min_samples_leaf)
  SVMKit::Validation.check_params_string(criterion: criterion)
  SVMKit::Validation.check_params_positive(max_depth: max_depth, max_leaf_nodes: max_leaf_nodes,
                                           min_samples_leaf: min_samples_leaf, max_features: max_features)
  @params = {}
  @params[:criterion] = criterion
  @params[:max_depth] = max_depth
  @params[:max_leaf_nodes] = max_leaf_nodes
  @params[:min_samples_leaf] = min_samples_leaf
  @params[:max_features] = max_features
  @params[:random_seed] = random_seed
  @params[:random_seed] ||= srand
  @criterion = :gini
  @criterion = :entropy if @params[:criterion] == 'entropy'
  @tree = nil
  @classes = nil
  @feature_importances = nil
  @n_leaves = nil
  @leaf_labels = nil
  @rng = Random.new(@params[:random_seed])
end
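The constructor accepts 'gini' or 'entropy' as the criterion, and any other string falls back to :gini. The listing does not show how either measure is computed, so the following is only a plain-Ruby sketch of the textbook definitions. The helper names class_proportions, gini, and entropy are hypothetical, and the base-2 logarithm is an assumption.

```ruby
# Hypothetical helpers, not part of SVMKit.

# Fraction of samples belonging to each distinct label.
def class_proportions(labels)
  counts = labels.each_with_object(Hash.new(0)) { |label, h| h[label] += 1 }
  counts.values.map { |c| c.to_f / labels.size }
end

# Gini impurity: 1 - sum(p_k^2). Zero when all labels are identical.
def gini(labels)
  1.0 - class_proportions(labels).sum { |p| p**2 }
end

# Entropy in bits (base-2 log is an assumption): -sum(p_k * log2(p_k)).
def entropy(labels)
  total = class_proportions(labels).sum { |p| p * Math.log2(p) }
  -total
end
```

Both measures are zero for a pure node and largest when the classes are evenly mixed, which is why either can be used to score candidate splits.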
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 26
def classes
  @classes
end
#feature_importances ⇒ Numo::DFloat (readonly)
Return the importance for each feature.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 30
def feature_importances
  @feature_importances
end
#leaf_labels ⇒ Numo::Int32 (readonly)
Return the labels assigned to each leaf.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 42
def leaf_labels
  @leaf_labels
end
#rng ⇒ Random (readonly)
Return the random generator used for random selection of feature indices.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 38
def rng
  @rng
end
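Because the generator is built from random_seed, two classifiers created with the same seed should draw the same feature indices. How SVMKit actually draws them is not shown on this page. The sketch below assumes a simple sample-without-replacement scheme, and sample_feature_ids is a hypothetical helper.

```ruby
# Hypothetical helper, not part of SVMKit: pick max_features distinct
# column indices out of n_features using the supplied generator.
def sample_feature_ids(n_features, max_features, rng)
  (0...n_features).to_a.sample(max_features, random: rng)
end

# Two generators seeded identically yield the same selection.
first_draw  = sample_feature_ids(10, 3, Random.new(1))
second_draw = sample_feature_ids(10, 3, Random.new(1))
puts first_draw == second_draw
```

Passing the generator explicitly, rather than relying on the global one, is what keeps the selection reproducible.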
#tree ⇒ Node (readonly)
Return the learned tree.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 34
def tree
  @tree
end
Instance Method Details
#apply(x) ⇒ Numo::Int32
Return the index of the leaf that each sample reached.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 122
def apply(x)
  SVMKit::Validation.check_sample_array(x)
  Numo::Int32[*(Array.new(x.shape[0]) { |n| apply_at_node(@tree, x[n, true]) })]
end
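The private apply_at_node helper is not listed on this page. As an illustration of what mapping a sample to a leaf index can look like, here is a plain-Ruby sketch. SketchNode, its fields, and the "less than or equal goes left" rule are all assumptions, not SVMKit's actual Node class.

```ruby
# Hypothetical node layout, not SVMKit's Node class.
# Internal nodes carry a split; leaves carry only a leaf_id.
SketchNode = Struct.new(:feature_id, :threshold, :left, :right, :leaf_id, keyword_init: true) do
  def leaf?
    !leaf_id.nil?
  end
end

# Walk from the root to a leaf and return that leaf's index.
# Assumption: values <= threshold go left, others go right.
def leaf_index(node, sample)
  until node.leaf?
    node = sample[node.feature_id] <= node.threshold ? node.left : node.right
  end
  node.leaf_id
end

# A small hand-built tree with three leaves.
def build_sketch_tree
  inner = SketchNode.new(feature_id: 1, threshold: 1.0,
                         left: SketchNode.new(leaf_id: 1),
                         right: SketchNode.new(leaf_id: 2))
  SketchNode.new(feature_id: 0, threshold: 2.5,
                 left: SketchNode.new(leaf_id: 0),
                 right: inner)
end
```

Applying leaf_index to every row of a sample matrix gives one leaf index per sample, which is the shape of result that #apply documents.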
#fit(x, y) ⇒ DecisionTreeClassifier
Fit the model with given training data.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 87
def fit(x, y)
  SVMKit::Validation.check_sample_array(x)
  SVMKit::Validation.check_label_array(y)
  SVMKit::Validation.check_sample_label_size(x, y)
  n_samples, n_features = x.shape
  @params[:max_features] = n_features if @params[:max_features].nil?
  @params[:max_features] = [@params[:max_features], n_features].min
  @classes = Numo::Int32.asarray(y.to_a.uniq.sort)
  build_tree(x, y)
  eval_importance(n_samples, n_features)
  self
end
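Two preprocessing steps in #fit are easy to miss: max_features defaults to the number of features and is clamped to it, and the class labels are stored as the sorted distinct values of y. The sketch below restates both with plain Ruby arrays instead of Numo arrays. The helper names are hypothetical.

```ruby
# Hypothetical helpers restating two steps of #fit with plain arrays.

# nil means "use every feature"; larger values are reduced to n_features.
def effective_max_features(max_features, n_features)
  [max_features || n_features, n_features].min
end

# Sorted distinct labels, the same values #classes would hold.
def distinct_classes(labels)
  labels.uniq.sort
end

puts effective_max_features(nil, 4) # falls back to n_features
puts distinct_classes([2, 0, 2, 1]).inspect
```

Note that #fit writes the clamped value back into the stored parameters, so the max_features entry of a fitted model may differ from the value passed to the constructor.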
#marshal_dump ⇒ Hash
Dump marshal data.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 129
def marshal_dump
  { params: @params,
    classes: @classes,
    criterion: @criterion,
    tree: @tree,
    feature_importances: @feature_importances,
    leaf_labels: @leaf_labels,
    rng: @rng }
end
#marshal_load(obj) ⇒ nil
Load marshal data.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 141
def marshal_load(obj)
  @params = obj[:params]
  @classes = obj[:classes]
  @criterion = obj[:criterion]
  @tree = obj[:tree]
  @feature_importances = obj[:feature_importances]
  @leaf_labels = obj[:leaf_labels]
  @rng = obj[:rng]
  nil
end
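Ruby's Marshal calls marshal_dump when serializing an object and marshal_load on a freshly allocated instance when restoring it, so this pair is what lets a fitted model be saved and reloaded. The sketch below shows the same pattern on a small stand-in class. SketchModel and round_trip are hypothetical and carry only a subset of the real state.

```ruby
# Hypothetical stand-in class using the same dump/load pattern.
class SketchModel
  attr_reader :params, :leaf_labels, :rng

  def initialize(seed)
    @params = { random_seed: seed }
    @leaf_labels = [2, 0, 1]
    @rng = Random.new(seed)
  end

  # Called by Marshal.dump: return the state to serialize.
  def marshal_dump
    { params: @params, leaf_labels: @leaf_labels, rng: @rng }
  end

  # Called by Marshal.load on an allocated (uninitialized) instance.
  def marshal_load(obj)
    @params = obj[:params]
    @leaf_labels = obj[:leaf_labels]
    @rng = obj[:rng]
    nil
  end
end

# Serialize to a byte string and restore a copy.
def round_trip(model)
  Marshal.load(Marshal.dump(model))
end
```

Since the generator is part of the dumped state, a restored copy continues the same random sequence as the original at the time it was dumped.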
#predict(x) ⇒ Numo::Int32
Predict class labels for samples.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 104
def predict(x)
  SVMKit::Validation.check_sample_array(x)
  @leaf_labels[apply(x)]
end
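#predict is a lookup: #apply yields one leaf index per sample, and those indices select entries of the leaf label array. The plain-Ruby sketch below mimics that indexing with Array#values_at in place of Numo's array indexing. The helper name and the values are made up for illustration.

```ruby
# Hypothetical helper: look up the label stored for each reached leaf.
# leaf_labels[i] is the label of leaf i; leaf_ids holds one index per sample.
def labels_for(leaf_labels, leaf_ids)
  leaf_labels.values_at(*leaf_ids)
end

# Leaves 0, 1, 2 carry labels 2, 0, 1; four samples landed in leaves 1, 1, 0, 2.
puts labels_for([2, 0, 1], [1, 1, 0, 2]).inspect
```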
#predict_proba(x) ⇒ Numo::DFloat
Predict probability for samples.
# File 'lib/svmkit/tree/decision_tree_classifier.rb', line 113
def predict_proba(x)
  SVMKit::Validation.check_sample_array(x)
  Numo::DFloat[*(Array.new(x.shape[0]) { |n| predict_at_node(@tree, x[n, true]) })]
end
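The private predict_at_node helper is not listed here, so how the per-leaf probabilities are derived is not confirmed by this page. A common approach for decision trees, shown below purely as an assumption, is to normalize the per-class training sample counts at the reached leaf. leaf_probabilities is a hypothetical helper.

```ruby
# Hypothetical helper, an assumption about one common approach:
# convert per-class sample counts at a leaf into probabilities.
def leaf_probabilities(class_counts)
  total = class_counts.sum.to_f
  class_counts.map { |count| count / total }
end

# A leaf that saw 3 samples of class 0, 1 of class 1, and none of class 2.
puts leaf_probabilities([3, 1, 0]).inspect
```

Under this scheme each row of the result would hold one value per entry of #classes, in the same order, and would sum to 1.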