Class: SVMKit::Multiclass::OneVsRestClassifier
- Inherits: Object
- Includes: Base::BaseEstimator, Base::Classifier
- Defined in: lib/svmkit/multiclass/one_vs_rest_classifier.rb
Overview
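OneVsRestClassifier is a class that implements the one-vs-rest (OvR) strategy for multiclass classification: it fits one binary classifier per class (that class versus all the others) and predicts the class whose classifier yields the highest confidence score.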
    require 'svmkit'

    # training_samples, testing_samples: NMatrix, shape [n_samples, n_features]
    # training_labels: NMatrix, shape [1, n_samples]
    base_estimator =
      SVMKit::LinearModel::PegasosSVC.new(penalty: 1.0, max_iter: 100, batch_size: 20, random_seed: 1)
    estimator = SVMKit::Multiclass::OneVsRestClassifier.new(estimator: base_estimator)
    estimator.fit(training_samples, training_labels)
    results = estimator.predict(testing_samples)
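The example above works on NMatrix data. To show the strategy itself with no dependencies, here is a pure-Ruby sketch on plain Arrays. ToyCentroidScorer is a hypothetical binary scorer standing in for PegasosSVC; it only mimics the duck-typed calls that OneVsRestClassifier makes on its base estimator (#dup, #fit(x, y) returning the fitted estimator, and #decision_function(x)). ovr_fit and ovr_predict are likewise illustrative helpers, not SVMKit API.

```ruby
# Hypothetical binary scorer (not part of SVMKit): it stores the mean of the
# +1 samples and of the -1 samples and scores by relative closeness.
class ToyCentroidScorer
  # Learn the positive and negative centroids; return self like SVMKit's #fit.
  def fit(x, y)
    @pos = mean(x.select.with_index { |_, i| y[i] == 1 })
    @neg = mean(x.select.with_index { |_, i| y[i] == -1 })
    self
  end

  # Higher score means closer to the positive centroid than to the negative one.
  def decision_function(x)
    x.map { |row| sq_dist(row, @neg) - sq_dist(row, @pos) }
  end

  private

  def mean(rows)
    rows.transpose.map { |col| col.sum.to_f / rows.size }
  end

  def sq_dist(a, b)
    a.zip(b).sum { |u, v| (u - v)**2 }
  end
end

# One-vs-rest training: one copy of the base estimator per class, fitted on +1/-1 targets.
def ovr_fit(base, x, y)
  classes = y.uniq.sort
  estimators = classes.map do |label|
    base.dup.fit(x, y.map { |l| l == label ? 1 : -1 })
  end
  [classes, estimators]
end

# One-vs-rest prediction: pick the class whose estimator gives the highest score.
def ovr_predict(classes, estimators, x)
  scores = estimators.map { |e| e.decision_function(x) }.transpose
  scores.map { |row| classes[row.index(row.max)] }
end

x = [[0.0, 0.0], [0.2, 0.1], [5.0, 5.0], [5.1, 4.9], [0.0, 5.0], [0.1, 5.2]]
y = [0, 0, 1, 1, 2, 2]
classes, estimators = ovr_fit(ToyCentroidScorer.new, x, y)
p ovr_predict(classes, estimators, [[0.1, 0.0], [4.9, 5.0], [0.0, 4.9]]) # → [0, 1, 2]
```

Any object with that small interface can play the role of the base estimator; this is the same duck typing the class relies on when it calls `params[:estimator].dup.fit(x, bin_y)`.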
Constant Summary
- DEFAULT_PARAMS = { estimator: nil }.freeze (marked # :nodoc: in the source)
Instance Attribute Summary
- #classes ⇒ Object (readonly)
  The class labels.
- #estimators ⇒ Object (readonly)
  The set of estimators.
Attributes included from Base::BaseEstimator
Instance Method Summary
- #decision_function(x) ⇒ Object
  Calculate confidence scores for samples.
- #fit(x, y) ⇒ Object
  Fit the model with the given training data.
- #initialize(params = {}) ⇒ OneVsRestClassifier (constructor)
  Create a new multiclass classifier with the one-vs-rest strategy.
- #marshal_dump ⇒ Object
  Serialize the object through Marshal.dump.
- #marshal_load(obj) ⇒ Object
  Deserialize the object through Marshal.load.
- #predict(x) ⇒ Object
  Predict class labels for samples.
- #score(x, y) ⇒ Object
  Calculate the mean accuracy on the given testing data.
Constructor Details
#initialize(params = {}) ⇒ OneVsRestClassifier
Create a new multiclass classifier with the one-vs-rest strategy.
:call-seq:
new(estimator: base_estimator) -> OneVsRestClassifier
Arguments:
- :estimator (Classifier) (defaults to: nil) – The binary classifier used to build the multiclass classifier. #fit trains a separate copy of it (obtained with #dup) for each class.
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 36
    def initialize(params = {})
      self.params = DEFAULT_PARAMS.merge(Hash[params.map { |k, v| [k.to_sym, v] }])
      @estimators = nil
      @classes = nil
    end
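Because the constructor converts every key to a Symbol before merging the options into DEFAULT_PARAMS, option keys may be given as Strings or Symbols, and omitted options fall back to the defaults. A plain-Ruby sketch of that normalization step; normalize_params is a hypothetical helper, not part of SVMKit:

```ruby
DEFAULT_PARAMS = { estimator: nil }.freeze

# Hypothetical helper mirroring the first line of #initialize.
def normalize_params(params = {})
  # Symbolize keys, then let user-supplied values override the defaults.
  DEFAULT_PARAMS.merge(Hash[params.map { |k, v| [k.to_sym, v] }])
end

p normalize_params
p normalize_params('estimator' => :some_binary_classifier)
```

Hash#merge returns a new Hash, so the frozen DEFAULT_PARAMS is never modified.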
Instance Attribute Details
#classes ⇒ Object (readonly)
The class labels: the sorted unique values of the labels given to #fit (nil before fitting).
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 27
    def classes
      @classes
    end
#estimators ⇒ Object (readonly)
The set of fitted binary estimators, one per entry of #classes and in the same order (nil before fitting).
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 24
    def estimators
      @estimators
    end
Instance Method Details
#decision_function(x) ⇒ Object
Calculate confidence scores for samples.
:call-seq:
decision_function(x) -> NMatrix, shape: [n_samples, n_classes]
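Arguments:
- x (NMatrix, shape: [n_samples, n_features]) – The samples for which to compute the scores.
Returns:
- Confidence scores per sample for each class, as an NMatrix of shape [n_samples, n_classes]; column m holds the scores of the estimator for classes[m].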
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 70
    def decision_function(x)
      n_samples, = x.shape
      n_classes = @classes.size
      NMatrix.new(
        [n_classes, n_samples],
        Array.new(n_classes) { |m| @estimators[m].decision_function(x).to_a }.flatten
      ).transpose
    end
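The reshaping in #decision_function can be traced with plain Arrays. Each binary estimator returns one score per sample, so stacking them gives an [n_classes, n_samples] table, and the transpose turns it into [n_samples, n_classes]. The sketch below assumes, as the code above does, that NMatrix.new fills a matrix from a flat Array row by row; the score values are made up for illustration.

```ruby
# Made-up scores: row m holds one score per sample from the estimator for classes[m].
per_class_scores = [
  [0.9, -0.4, -1.2], # estimator for classes[0]
  [-0.7, 0.8, -0.3], # estimator for classes[1]
  [-1.1, -0.6, 0.5]  # estimator for classes[2]
]
n_samples = per_class_scores.first.size

# Flatten, then refill row by row into an [n_classes, n_samples] table.
flat = per_class_scores.flatten
stacked = flat.each_slice(n_samples).to_a

# Transposing yields [n_samples, n_classes]: row i holds sample i's score for every class.
scores = stacked.transpose
scores.each_with_index { |row, i| puts "sample #{i}: #{row.inspect}" }
```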
#fit(x, y) ⇒ Object
Fit the model with given training data.
:call-seq:
fit(x, y) -> OneVsRestClassifier
Arguments:
- x (NMatrix, shape: [n_samples, n_features]) – The training data to be used for fitting the model.
- y (NMatrix, shape: [1, n_samples]) – The labels to be used for fitting the model.
Returns:
- The learned classifier itself.
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 52
    def fit(x, y)
      @classes = y.uniq.sort
      @estimators = @classes.map do |label|
        bin_y = y.map { |l| l == label ? 1 : -1 }
        params[:estimator].dup.fit(x, bin_y)
      end
      self
    end
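The core of #fit is the label binarization: for each class, samples of that class get the target +1 and every other sample gets -1. A plain-Array sketch with hypothetical labels:

```ruby
# Hypothetical training labels (a plain Array instead of an NMatrix).
y = [2, 0, 1, 0, 2]
# As in #fit: the sorted unique labels define the classes and their order.
classes = y.uniq.sort
# One +1/-1 target vector per class: +1 for that class, -1 for all the rest.
binary_targets = classes.map { |label| y.map { |l| l == label ? 1 : -1 } }
binary_targets.each_with_index do |bin_y, m|
  puts "class #{classes[m]}: #{bin_y.inspect}"
end
```

Each target vector is then passed to a fresh `dup` of the base estimator. Object#dup makes a shallow copy, so this presumably relies on the base estimator's #fit assigning new objects to its instance variables rather than mutating shared ones in place.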
#marshal_dump ⇒ Object
Serialize the object through Marshal.dump.
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 112
    def marshal_dump # :nodoc:
      { params: params,
        classes: @classes,
        estimators: @estimators.map { |e| Marshal.dump(e) } }
    end
#marshal_load(obj) ⇒ Object
Deserialize the object through Marshal.load.
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 119
    def marshal_load(obj) # :nodoc:
      self.params = obj[:params]
      @classes = obj[:classes]
      @estimators = obj[:estimators].map { |e| Marshal.load(e) }
      nil
    end
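The #marshal_dump / #marshal_load pair can be exercised without SVMKit using a toy class. Marshal calls #marshal_dump when dumping, and on loading it allocates a fresh instance and passes the dumped value to #marshal_load. ToyEnsemble and Inner are hypothetical names (Inner stands in for a fitted binary estimator), and the params entry of the original is left out for brevity.

```ruby
# Hypothetical stand-in for a fitted binary estimator.
Inner = Struct.new(:weights)

class ToyEnsemble
  attr_reader :classes, :estimators

  def initialize(classes, estimators)
    @classes = classes
    @estimators = estimators
  end

  # As in the original: each inner estimator is dumped to its own byte string.
  def marshal_dump
    { classes: @classes, estimators: @estimators.map { |e| Marshal.dump(e) } }
  end

  # Called on a freshly allocated object; rebuilds the state from the Hash.
  def marshal_load(obj)
    @classes = obj[:classes]
    @estimators = obj[:estimators].map { |e| Marshal.load(e) }
    nil
  end
end

original = ToyEnsemble.new([0, 1], [Inner.new([0.5, -1.0]), Inner.new([2.0, 0.25])])
restored = Marshal.load(Marshal.dump(original))
p restored.classes
p restored.estimators.map(&:weights)
```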
#predict(x) ⇒ Object
Predict class labels for samples.
:call-seq:
predict(x) -> NMatrix, shape: [1, n_samples]
Arguments:
- x (NMatrix, shape: [n_samples, n_features]) – The samples whose labels are to be predicted.
Returns:
- Predicted class label per sample.
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 88
    def predict(x)
      n_samples, = x.shape
      decision_values = decision_function(x)
      NMatrix.new([1, n_samples],
                  decision_values.each_row.map { |vals| @classes[vals.to_a.index(vals.to_a.max)] })
    end
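#predict picks, for each row of the decision scores, the class whose column holds the largest value. A plain-Array sketch with hypothetical labels and scores; note that Array#index returns the first match, so a tie goes to the class that comes first in the sorted #classes.

```ruby
# Sorted class labels, as #fit would store them (hypothetical values).
classes = %w[bird cat dog]
# One row per sample, one column per class, as returned by #decision_function.
scores = [
  [0.9, -0.7, -1.1],
  [-0.4, 0.8, -0.6],
  [-1.2, -0.3, 0.5],
  [0.3, 0.3, -1.0] # tie between the first two classes
]
# Same rule as #predict: index of the largest score, mapped back to its label.
predicted = scores.map { |row| classes[row.index(row.max)] }
p predicted
```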
#score(x, y) ⇒ Object
Calculate the mean accuracy on the given testing data.
:call-seq:
score(x, y) -> Float
Arguments:
- x (NMatrix, shape: [n_samples, n_features]) – Testing data.
- y (NMatrix, shape: [1, n_samples]) – True labels for the testing data.
Returns:
- Mean accuracy, i.e. the fraction of samples whose predicted label equals the true label.
    # File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 105
    def score(x, y)
      p = predict(x)
      n_hits = (y.to_flat_a.map.with_index { |l, n| l == p[n] ? 1 : 0 }).inject(:+)
      n_hits / y.size.to_f
    end
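The accuracy computation in #score can be sketched on plain Arrays. mean_accuracy is a hypothetical helper; it uses Enumerable#count in place of the original map/inject, which counts the same hits.

```ruby
# Hypothetical plain-Array version of the #score computation.
def mean_accuracy(y_true, y_pred)
  # Count samples whose predicted label equals the true label.
  n_hits = y_true.each_with_index.count { |label, n| label == y_pred[n] }
  # to_f avoids Integer division (3 / 4 would be 0).
  n_hits / y_true.size.to_f
end

p mean_accuracy([0, 1, 2, 1], [0, 2, 2, 1]) # → 0.75
```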