Class: SVMKit::Multiclass::OneVsRestClassifier

Inherits:
Object
Includes:
Base::BaseEstimator, Base::Classifier
Defined in:
lib/svmkit/multiclass/one_vs_rest_classifier.rb

Overview

OneVsRestClassifier is a class that implements the One-vs-Rest (OvR) strategy for multiclass classification.

base_estimator =
 SVMKit::LinearModel::PegasosSVC.new(penalty: 1.0, max_iter: 100, batch_size: 20, random_seed: 1)
estimator = SVMKit::Multiclass::OneVsRestClassifier.new(estimator: base_estimator)
estimator.fit(training_samples, training_labels)
results = estimator.predict(testing_samples)
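The one-vs-rest strategy trains one binary classifier per class, with that class labelled +1 and every other class -1, and predicts the class whose classifier returns the highest score. The following is a minimal plain-Ruby sketch of that idea, using nested arrays instead of NMatrix. `CentroidScorer` and `SimpleOneVsRest` are hypothetical classes written for illustration only; they are not part of SVMKit, and the toy scorer merely stands in for a real binary estimator such as PegasosSVC.

```ruby
# Hypothetical toy binary estimator (not part of SVMKit). A sample scores higher
# the closer it is to the mean of the +1 samples relative to the mean of the -1 samples.
class CentroidScorer
  def fit(x, y)
    pos = x.select.with_index { |_, i| y[i] == 1 }
    neg = x.select.with_index { |_, i| y[i] == -1 }
    @pos_mean = mean(pos)
    @neg_mean = mean(neg)
    self
  end

  # One score per sample; larger means "more likely the positive class".
  def decision_function(x)
    x.map { |s| sq_dist(s, @neg_mean) - sq_dist(s, @pos_mean) }
  end

  private

  def mean(rows)
    rows.transpose.map { |col| col.sum / rows.size.to_f }
  end

  def sq_dist(a, b)
    a.zip(b).sum { |u, v| (u - v)**2 }
  end
end

# Hypothetical one-vs-rest wrapper over plain arrays (a sketch, not the SVMKit class).
class SimpleOneVsRest
  attr_reader :classes

  def initialize(estimator)
    @estimator = estimator
  end

  def fit(x, y)
    @classes = y.uniq.sort
    @estimators = @classes.map do |label|
      bin_y = y.map { |l| l == label ? 1 : -1 } # this class vs. the rest
      @estimator.dup.fit(x, bin_y)              # fresh copy per class
    end
    self
  end

  def predict(x)
    # [n_classes][n_samples] -> [n_samples][n_classes]
    scores = @estimators.map { |e| e.decision_function(x) }.transpose
    scores.map { |row| @classes[row.index(row.max)] }
  end
end

x = [[0.0, 0.0], [0.2, 0.1], [5.0, 5.0], [5.1, 4.9], [0.0, 5.0], [0.1, 5.2]]
y = [0, 0, 1, 1, 2, 2]
model = SimpleOneVsRest.new(CentroidScorer.new).fit(x, y)
model.predict([[0.1, 0.0], [4.9, 5.1], [0.0, 4.8]]) # => [0, 1, 2]
```

Taking the argmax over raw scores assumes the scores of the different binary estimators are comparable, which is the same assumption #predict makes.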

Constant Summary

DEFAULT_PARAMS =

{ # :nodoc:
  estimator: nil
}.freeze

Instance Attribute Summary

Attributes included from Base::BaseEstimator

#params

Instance Method Summary

Constructor Details

#initialize(params = {}) ⇒ OneVsRestClassifier

Create a new multiclass classifier with the one-vs-rest strategy.

:call-seq:

new(estimator: base_estimator) -> OneVsRestClassifier
  • Arguments :

    • :estimator (Classifier) (defaults to: nil) – The binary classifier used to construct the multiclass classifier; a copy of it is trained for each class.



# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 36

def initialize(params = {})
  self.params = DEFAULT_PARAMS.merge(Hash[params.map { |k, v| [k.to_sym, v] }])
  @estimators = nil
  @classes = nil
end
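The constructor converts parameter keys to symbols and merges them over DEFAULT_PARAMS, so keys can be given as strings or symbols and any key that is not supplied keeps its default. The plain-Ruby sketch below reproduces that merge in isolation. `build_params` is a hypothetical helper name, and `Hash#transform_keys` (Ruby 2.5+) is used as an equivalent of the `Hash[params.map { ... }]` idiom in the original.

```ruby
DEFAULT_PARAMS = { estimator: nil }.freeze

# Hypothetical helper mirroring the constructor: symbolize the keys, then let
# caller-supplied values override the frozen defaults (merge returns a new Hash).
def build_params(params = {})
  DEFAULT_PARAMS.merge(params.transform_keys(&:to_sym))
end

build_params[:estimator]                          # => nil
build_params('estimator' => :my_svc)[:estimator]  # => :my_svc
```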

Instance Attribute Details

#classes ⇒ Object (readonly)

The class labels: the sorted, distinct labels seen by #fit.



# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 27

def classes
  @classes
end

#estimators ⇒ Object (readonly)

The fitted binary estimators, one per class, in the same order as #classes.



# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 24

def estimators
  @estimators
end

Instance Method Details

#decision_function(x) ⇒ Object

Calculate confidence scores for samples.

:call-seq:

decision_function(x) -> NMatrix, shape: [n_samples, n_classes]
  • Arguments :

    • x (NMatrix, shape: [n_samples, n_features]) – The samples for which to compute the scores.

  • Returns :

    • Confidence scores per sample for each class.



# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 70

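def decision_function(x)
  n_samples, = x.shape
  n_classes = @classes.size
  # Stack each binary estimator's scores as one row ([n_classes, n_samples]),
  # then transpose to get one row per sample ([n_samples, n_classes]).
  NMatrix.new(
    [n_classes, n_samples],
    Array.new(n_classes) { |m| @estimators[m].decision_function(x).to_a }.flatten
  ).transpose
end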
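Each binary estimator returns one score per sample, so stacking the per-class outputs gives an [n_classes, n_samples] layout, and the final transpose turns it into [n_samples, n_classes] with one row per sample. The sketch below shows the same reshaping with nested arrays; `per_class_scores` holds made-up values standing in for the outputs of three hypothetical binary estimators.

```ruby
# Hypothetical scores from three binary estimators for four samples:
# row m holds the scores of the estimator for class m (shape [n_classes, n_samples]).
per_class_scores = [
  [0.9, -0.4, -1.2, 0.1],
  [-0.7, 0.8, -0.3, -0.2],
  [-0.5, -0.9, 0.6, 0.3]
]

# Transposing gives one row per sample and one column per class,
# matching the [n_samples, n_classes] shape returned by decision_function.
scores = per_class_scores.transpose
scores.length    # => 4
scores[0].length # => 3
scores[2]        # => [-1.2, -0.3, 0.6]
```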

#fit(x, y) ⇒ Object

Fit the model with the given training data.

:call-seq:

fit(x, y) -> OneVsRestClassifier
  • Arguments :

    • x (NMatrix, shape: [n_samples, n_features]) – The training data to be used for fitting the model.

    • y (NMatrix, shape: [1, n_samples]) – The labels to be used for fitting the model.

  • Returns :

    • The learned classifier itself.



# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 52

def fit(x, y)
  @classes = y.uniq.sort
  # Train one binary estimator per class: that class is labelled 1, all others -1.
  @estimators = @classes.map do |label|
    bin_y = y.map { |l| l == label ? 1 : -1 }
    params[:estimator].dup.fit(x, bin_y)
  end
  self
end
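For each distinct label, in sorted order, #fit builds a binary target in which that label becomes 1 and every other label becomes -1, then trains a copy of the base estimator on it. The relabeling step on its own looks like this with a plain array of made-up labels:

```ruby
y = [2, 0, 1, 2, 0]
classes = y.uniq.sort # distinct labels in ascending order

# One binary target per class: 1 for that class, -1 for all the others.
binary_targets = classes.map do |label|
  y.map { |l| l == label ? 1 : -1 }
end

classes           # => [0, 1, 2]
binary_targets[0] # => [-1, 1, -1, -1, 1]
binary_targets[2] # => [1, -1, -1, 1, -1]
```

Because Ruby's `dup` makes a shallow copy, this pattern assumes the base estimator's `fit` assigns fresh internal state rather than mutating objects shared with the unfitted original.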

#marshal_dump ⇒ Object

Serialize the object through Marshal.dump.



# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 112

def marshal_dump # :nodoc:
  { params: params,
    classes: @classes,
    estimators: @estimators.map { |e| Marshal.dump(e) } }
end

#marshal_load(obj) ⇒ Object

Deserialize the object through Marshal.load.



# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 119

def marshal_load(obj) # :nodoc:
  self.params = obj[:params]
  @classes = obj[:classes]
  @estimators = obj[:estimators].map { |e| Marshal.load(e) }
  nil
end
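#marshal_dump returns a plain hash of the object's state (serializing each fitted estimator separately with Marshal.dump), and #marshal_load restores the instance variables from that hash, which lets the whole classifier be saved with Marshal.dump and restored with Marshal.load. The self-contained sketch below shows the same two hooks on a hypothetical `TinyModel` class; it does not use SVMKit.

```ruby
# Hypothetical class using the same custom-serialization hooks.
class TinyModel
  attr_reader :params, :classes

  def initialize(params = {})
    @params = params
    @classes = nil
  end

  def fit(y)
    @classes = y.uniq.sort
    self
  end

  # Called by Marshal.dump: return the state to serialize.
  def marshal_dump
    { params: @params, classes: @classes }
  end

  # Called by Marshal.load on a freshly allocated object: restore the state.
  def marshal_load(obj)
    @params = obj[:params]
    @classes = obj[:classes]
    nil
  end
end

model = TinyModel.new(penalty: 1.0).fit([3, 1, 3, 2])
restored = Marshal.load(Marshal.dump(model))
restored.classes          # => [1, 2, 3]
restored.params[:penalty] # => 1.0
```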

#predict(x) ⇒ Object

Predict class labels for samples.

:call-seq:

predict(x) -> NMatrix, shape: [1, n_samples]
  • Arguments :

    • x (NMatrix, shape: [n_samples, n_features]) – The samples for which to predict labels.

  • Returns :

    • Predicted class label per sample.



# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 88

def predict(x)
  n_samples, = x.shape
  decision_values = decision_function(x)
  # For each sample (row), return the label of the class with the highest score.
  NMatrix.new([1, n_samples],
              decision_values.each_row.map { |vals| @classes[vals.to_a.index(vals.to_a.max)] })
end
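#predict takes, for each row of the decision scores, the position of the largest score (an argmax over the row) and looks up the label at that position in #classes. With nested arrays standing in for the NMatrix and made-up scores, the step is:

```ruby
classes = ['cat', 'dog', 'fox'] # sorted labels, as produced by fit
# Hypothetical decision scores, shape [n_samples, n_classes].
scores = [
  [0.9, -0.7, -0.5],
  [-0.4, 0.8, -0.9],
  [-1.2, -0.3, 0.6]
]

# Argmax per row. Array#index returns the first position of the maximum,
# so ties resolve to the class that comes first in sorted order.
predicted = scores.map { |row| classes[row.index(row.max)] }
predicted # => ["cat", "dog", "fox"]
```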

#score(x, y) ⇒ Object

Calculate the mean accuracy on the given testing data.

:call-seq:

score(x, y) -> Float
  • Arguments :

    • x (NMatrix, shape: [n_samples, n_features]) – Testing data.

    • y (NMatrix, shape: [1, n_samples]) – True labels for testing data.

  • Returns :

    • Mean accuracy, i.e. the fraction of samples whose predicted label matches the true label.



# File 'lib/svmkit/multiclass/one_vs_rest_classifier.rb', line 105

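def score(x, y)
  # Flatten the [1, n_samples] prediction so it can be indexed by sample position;
  # the local is not named `p` to avoid shadowing Kernel#p.
  predicted = predict(x).to_flat_a
  n_hits = y.to_flat_a.map.with_index { |l, n| l == predicted[n] ? 1 : 0 }.inject(:+)
  n_hits / y.size.to_f
end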
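The mean accuracy is the number of samples whose predicted label equals the true label, divided by the total number of samples. A plain-array version of the same computation, with made-up labels, is:

```ruby
y_true = [0, 1, 2, 2, 1]
y_pred = [0, 2, 2, 2, 1]

# Count exact matches position by position, then divide by the sample count.
n_hits = y_true.zip(y_pred).count { |truth, pred| truth == pred }
accuracy = n_hits / y_true.size.to_f
accuracy # => 0.8
```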