Class: SVMKit::Multiclass::OneVsRestClassifier
- Inherits: Object (Object > SVMKit::Multiclass::OneVsRestClassifier)
- Includes: Base::BaseEstimator, Base::Classifier
- Defined in: lib/svmkit/multiclass/one_vs_rest_classifier.rb
Overview
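OneVsRestClassifier is a class that implements the One-vs-Rest (OvR) strategy for multi-class classification. It trains one binary estimator per class label, treating that class as positive and all other classes as negative, and predicts the class whose estimator returns the highest decision value.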
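To make the strategy concrete, here is a minimal plain-Ruby sketch that uses ordinary Arrays instead of Numo arrays. Both classes are hypothetical and not part of SVMKit: `NearestMeanScorer` is a toy binary learner whose score is positive when a sample is closer to the mean of the +1 samples than to the mean of the -1 samples, and `ToyOneVsRest` only mirrors the fit/predict flow of this class.

```ruby
# Hypothetical binary learner (not part of SVMKit): the score is positive when a
# sample is closer to the mean of the +1 samples than to the mean of the -1 samples.
class NearestMeanScorer
  def fit(x, y)
    @pos_mean = mean(x.select.with_index { |_, i| y[i] == 1 })
    @neg_mean = mean(x.select.with_index { |_, i| y[i] == -1 })
    self
  end

  def decision_function(x)
    x.map { |row| sq_dist(row, @neg_mean) - sq_dist(row, @pos_mean) }
  end

  private

  def mean(rows)
    rows.transpose.map { |col| col.sum / rows.size.to_f }
  end

  def sq_dist(a, b)
    a.zip(b).sum { |ai, bi| (ai - bi)**2 }
  end
end

# Illustrative one-vs-rest wrapper over plain Ruby Arrays (not the SVMKit class).
class ToyOneVsRest
  attr_reader :classes

  def initialize(estimator:)
    @estimator = estimator
  end

  def fit(x, y)
    @classes = y.uniq.sort
    @estimators = @classes.map do |label|
      bin_y = y.map { |l| l == label ? 1 : -1 } # this class vs. the rest
      @estimator.dup.fit(x, bin_y)              # one fresh copy per class
    end
    self
  end

  def predict(x)
    # Row n holds the scores of sample n under every binary estimator.
    scores = @estimators.map { |e| e.decision_function(x) }.transpose
    scores.map { |row| @classes[row.index(row.max)] }
  end
end

x = [[0.0, 0.0], [0.2, 0.1], [5.0, 5.0], [5.1, 4.9], [0.0, 5.0], [0.1, 5.2]]
y = [1, 1, 2, 2, 3, 3]
model = ToyOneVsRest.new(estimator: NearestMeanScorer.new).fit(x, y)
p model.predict([[0.1, 0.0], [4.9, 5.0], [0.0, 4.8]]) # => [1, 2, 3]
```

With k classes this trains k binary models, and the final decision is an argmax over their k scores for each sample.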
Instance Attribute Summary
- #classes ⇒ Numo::Int32 (readonly): Return the class labels.
- #estimators ⇒ Array<Classifier> (readonly): Return the set of binary estimators, one per class label.
Attributes included from Base::BaseEstimator
Instance Method Summary
- #decision_function(x) ⇒ Numo::DFloat: Calculate confidence scores for samples.
- #fit(x, y) ⇒ OneVsRestClassifier: Fit the model with the given training data.
- #initialize(estimator: nil) ⇒ OneVsRestClassifier (constructor): Create a new multi-class classifier with the one-vs-rest strategy.
- #marshal_dump ⇒ Hash: Dump marshal data.
- #marshal_load(obj) ⇒ nil: Load marshal data.
- #predict(x) ⇒ Numo::Int32: Predict class labels for samples.
- #score(x, y) ⇒ Float: Calculate the mean accuracy on the given test data.
Constructor Details
#initialize(estimator: nil) ⇒ OneVsRestClassifier
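Create a new multi-class classifier with the one-vs-rest strategy. The given estimator serves as a prototype: #fit trains a copy of it (made with dup) for each class label.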
# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 30

def initialize(estimator: nil)
  @params = {}
  @params[:estimator] = estimator
  @estimators = nil
  @classes = nil
end
Instance Attribute Details
#classes ⇒ Numo::Int32 (readonly)
Return the class labels.
# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 25

def classes
  @classes
end
#estimators ⇒ Array<Classifier> (readonly)
Return the set of binary estimators, one per class label, in the same order as #classes.
# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 21

def estimators
  @estimators
end
Instance Method Details
#decision_function(x) ⇒ Numo::DFloat
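Calculate confidence scores for samples. The result has one row per sample and one column per class; column m holds the decision values of the m-th binary estimator.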
# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 56

def decision_function(x)
  n_classes = @classes.size
  # Collect each binary estimator's scores (one row per class), then
  # transpose to get an n_samples x n_classes matrix.
  Numo::DFloat.asarray(Array.new(n_classes) { |m| @estimators[m].decision_function(x).to_a }).transpose
end
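The transpose at the end of decision_function turns per-class score vectors into a per-sample matrix. This plain-Ruby sketch applies Array#transpose to made-up decision values (hypothetical numbers, not SVMKit output) for 3 classes and 4 samples to show the resulting layout:

```ruby
# Hypothetical per-class decision values: one Array of length n_samples (4)
# per class (3), as the m-th binary estimator would produce them.
per_class_scores = [
  [0.9, -0.4, -1.2, 0.1],   # estimator for the 1st class
  [-0.7, 0.3, -0.1, -0.2],  # estimator for the 2nd class
  [-0.5, -0.8, 0.6, -0.3]   # estimator for the 3rd class
]

# Transposing yields one row per sample and one column per class.
scores = per_class_scores.transpose
p scores          # => [[0.9, -0.7, -0.5], [-0.4, 0.3, -0.8], [-1.2, -0.1, 0.6], [0.1, -0.2, -0.3]]
p scores.size     # => 4 (samples)
p scores[0].size  # => 3 (classes)
```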
#fit(x, y) ⇒ OneVsRestClassifier
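Fit the model with the given training data. For each class label in y, a copy of the base estimator is trained on binary targets: +1 for that label and -1 for all others.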
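# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 42

def fit(x, y)
  y_arr = y.to_a
  # Sorted unique labels; their order fixes the column order of decision_function.
  @classes = Numo::Int32.asarray(y_arr.uniq.sort)
  @estimators = @classes.to_a.map do |label|
    # +1 for the current class, -1 for every other class.
    bin_y = Numo::Int32.asarray(y_arr.map { |l| l == label ? 1 : -1 })
    # Train a fresh copy of the prototype estimator on the binary targets.
    @params[:estimator].dup.fit(x, bin_y)
  end
  self
end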
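The two ideas inside fit, relabeling the targets as +1/-1 per class and training a separate dup of the prototype estimator, can be seen in isolation with a hypothetical RecordingEstimator (not an SVMKit class) that only remembers the targets it was fitted on:

```ruby
# Hypothetical stand-in for a binary estimator: fitting reassigns an
# instance variable, as typical estimators do.
class RecordingEstimator
  attr_reader :seen_targets

  def fit(_x, y)
    @seen_targets = y
    self
  end
end

y = [2, 0, 1, 0, 2]
classes = y.uniq.sort # => [0, 1, 2]

prototype = RecordingEstimator.new
estimators = classes.map do |label|
  bin_y = y.map { |l| l == label ? 1 : -1 } # +1 for this class, -1 for the rest
  prototype.dup.fit(nil, bin_y)             # each class gets its own copy
end

estimators.each_with_index { |e, m| puts "class #{classes[m]}: #{e.seen_targets.inspect}" }
# class 0: [-1, 1, -1, 1, -1]
# class 1: [-1, -1, 1, -1, -1]
# class 2: [1, -1, -1, -1, 1]
p prototype.seen_targets # => nil (the prototype itself is never fitted)
```

Object#dup makes a shallow copy, so this pattern relies on fit assigning new objects to instance variables rather than mutating objects shared with the prototype.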
#marshal_dump ⇒ Hash
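Dump marshal data: the parameters, the class labels, and each binary estimator serialized separately with Marshal.dump.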
# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 84

def marshal_dump
  { params: @params,
    classes: @classes,
    estimators: @estimators.map { |e| Marshal.dump(e) } }
end
#marshal_load(obj) ⇒ nil
Load marshal data produced by #marshal_dump, restoring each estimator with Marshal.load.
# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 92

def marshal_load(obj)
  @params = obj[:params]
  @classes = obj[:classes]
  @estimators = obj[:estimators].map { |e| Marshal.load(e) }
  nil
end
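marshal_dump stores each estimator as its own Marshal.dump string, and marshal_load rebuilds them. The round trip below uses hypothetical classes TinyEstimator and TinyOneVsRest (not SVMKit classes) that follow the same pattern. It relies on Ruby's Marshal protocol: Marshal.load allocates the object without calling initialize and then passes the dumped Hash to marshal_load.

```ruby
# Hypothetical minimal estimator holding fitted state.
class TinyEstimator
  attr_reader :weight

  def initialize(weight)
    @weight = weight
  end
end

# Hypothetical wrapper mirroring the marshal_dump / marshal_load pattern:
# each inner estimator is serialized separately into a String.
class TinyOneVsRest
  attr_reader :classes, :estimators

  def initialize(classes, estimators)
    @classes = classes
    @estimators = estimators
  end

  def marshal_dump
    { classes: @classes, estimators: @estimators.map { |e| Marshal.dump(e) } }
  end

  def marshal_load(obj)
    @classes = obj[:classes]
    @estimators = obj[:estimators].map { |e| Marshal.load(e) }
    nil
  end
end

model = TinyOneVsRest.new([0, 1], [TinyEstimator.new(0.5), TinyEstimator.new(-1.5)])
restored = Marshal.load(Marshal.dump(model))
p restored.classes                  # => [0, 1]
p restored.estimators.map(&:weight) # => [0.5, -1.5]
```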
#predict(x) ⇒ Numo::Int32
Predict class labels for samples. Each sample is assigned the class whose binary estimator returns the highest decision value.
# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 65

def predict(x)
  n_samples, = x.shape
  decision_values = decision_function(x)
  # For each sample, pick the label whose estimator scored highest.
  Numo::Int32.asarray(Array.new(n_samples) { |n| @classes[decision_values[n, true].max_index] })
end
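In predict, the column index chosen for each row is mapped back through @classes, so the result contains the original labels rather than column positions. A plain-Ruby sketch with hypothetical labels and scores:

```ruby
# Hypothetical sorted class labels and decision values
# (one row per sample, one column per class).
classes = [3, 5, 7]
decision_values = [
  [0.2, -0.1, 0.9],
  [1.4, 0.3, -0.2],
  [-0.3, 0.8, 0.1]
]

# Take the column with the largest score in each row and translate it
# into the corresponding class label.
predictions = decision_values.map { |row| classes[row.index(row.max)] }
p predictions # => [7, 3, 5]
```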
#score(x, y) ⇒ Float
Calculate the mean accuracy on the given test data, i.e. the fraction of samples whose predicted label equals the true label in y.
# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 76

def score(x, y)
  p = predict(x)
  # Count samples whose predicted label matches the true label.
  n_hits = (y.to_a.map.with_index { |l, n| l == p[n] ? 1 : 0 }).inject(:+)
  n_hits / y.size.to_f
end
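score counts matching labels and divides by the number of samples. The same arithmetic on plain Arrays, written as a hypothetical mean_accuracy helper (not an SVMKit method):

```ruby
# Mean accuracy: the fraction of samples whose predicted label equals the
# true label.
def mean_accuracy(y_true, y_pred)
  n_hits = y_true.each_with_index.count { |label, n| label == y_pred[n] }
  n_hits / y_true.size.to_f
end

y_true = [0, 1, 2, 2, 1]
y_pred = [0, 2, 2, 2, 1]
p mean_accuracy(y_true, y_pred) # => 0.8 (4 of 5 labels match)
```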