Class: Rumale::NearestNeighbors::KNeighborsClassifier

Inherits:
Object
  • Object
show all
Includes:
Base::BaseEstimator, Base::Classifier
Defined in:
lib/rumale/nearest_neighbors/k_neighbors_classifier.rb

Overview

KNeighborsClassifier implements a classifier based on the k-nearest neighbors rule. The current implementation uses the Euclidean distance to find the neighbors.

Examples:

estimator =
  Rumale::NearestNeighbors::KNeighborsClassifier.new(n_neighbors: 5)
estimator.fit(training_samples, training_labels)
results = estimator.predict(testing_samples)

Instance Attribute Summary collapse

Attributes included from Base::BaseEstimator

#params

Instance Method Summary collapse

Methods included from Base::Classifier

#score

Constructor Details

#initialize(n_neighbors: 5) ⇒ KNeighborsClassifier

Create a new classifier with the nearest neighbor rule.

Parameters:

  • n_neighbors (Integer) (defaults to: 5)

    The number of neighbors.



37
38
39
40
41
42
43
44
45
# File 'lib/rumale/nearest_neighbors/k_neighbors_classifier.rb', line 37

# Create a new classifier with the k-nearest neighbors rule.
#
# @param n_neighbors [Integer] The number of neighbors (must be a positive integer).
def initialize(n_neighbors: 5)
  # Validate the hyperparameter before storing anything.
  check_params_integer(n_neighbors: n_neighbors)
  check_params_positive(n_neighbors: n_neighbors)
  @params = { n_neighbors: n_neighbors }
  # The learned state stays empty until #fit is called.
  @prototypes = @labels = @classes = nil
end

Instance Attribute Details

#classesNumo::Int32 (readonly)

Return the class labels.

Returns:

  • (Numo::Int32)

    (size: n_classes)



32
33
34
# File 'lib/rumale/nearest_neighbors/k_neighbors_classifier.rb', line 32

# @return [Numo::Int32] (size: n_classes) The class labels, or nil before fitting.
def classes; @classes; end

#labelsNumo::Int32 (readonly)

Return the labels of the prototypes.

Returns:

  • (Numo::Int32)

    (size: n_samples)



28
29
30
# File 'lib/rumale/nearest_neighbors/k_neighbors_classifier.rb', line 28

# @return [Numo::Int32] (size: n_samples) The label of each stored prototype, or nil before fitting.
def labels; @labels; end

#prototypesNumo::DFloat (readonly)

Return the prototypes for the nearest neighbor classifier.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples, n_features])



24
25
26
# File 'lib/rumale/nearest_neighbors/k_neighbors_classifier.rb', line 24

# @return [Numo::DFloat] (shape: [n_samples, n_features]) The stored training samples, or nil before fitting.
def prototypes; @prototypes; end

Instance Method Details

#decision_function(x) ⇒ Numo::DFloat

Calculate confidence scores for samples.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples to compute the scores.

Returns:

  • (Numo::DFloat)

    (shape: [n_samples, n_classes]) Confidence scores per sample for each class.



66
67
68
69
70
71
72
73
74
75
76
77
78
# File 'lib/rumale/nearest_neighbors/k_neighbors_classifier.rb', line 66

# Calculate confidence scores for samples.
#
# The score of a class is the number of votes it receives among the
# k nearest prototypes (Euclidean distance). k is clamped to the number
# of stored prototypes.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to compute the scores.
# @return [Numo::DFloat] (shape: [n_samples, n_classes]) Confidence scores per sample for each class.
def decision_function(x)
  check_sample_array(x)
  distance_matrix = PairwiseMetric.euclidean_distance(x, @prototypes)
  n_samples, n_prototypes = distance_matrix.shape
  n_neighbors = [@params[:n_neighbors], n_prototypes].min
  # Build the label -> score-column lookup once instead of doing a linear
  # search over the class list for every neighbor of every sample.
  class_column = @classes.to_a.each_with_index.to_h
  label_list = @labels.to_a
  scores = Numo::DFloat.zeros(n_samples, @classes.size)
  n_samples.times do |m|
    # Sorting [distance, index] pairs orders by distance and breaks ties by
    # the lower prototype index, so results are deterministic.
    neighbor_ids = distance_matrix[m, true].to_a.each_with_index.sort.first(n_neighbors).map(&:last)
    neighbor_ids.each { |n| scores[m, class_column[label_list[n]]] += 1.0 }
  end
  scores
end

#fit(x, y) ⇒ KNeighborsClassifier

Fit the model with given training data.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The training data to be used for fitting the model.

  • y (Numo::Int32)

    (shape: [n_samples]) The labels to be used for fitting the model.

Returns:

  • (KNeighborsClassifier)

    The learned classifier itself.



52
53
54
55
56
57
58
59
60
# File 'lib/rumale/nearest_neighbors/k_neighbors_classifier.rb', line 52

# Fit the model with given training data by memorizing the samples and labels.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The training data.
# @param y [Numo::Int32] (shape: [n_samples]) The labels of the training data.
# @return [KNeighborsClassifier] The learned classifier itself.
def fit(x, y)
  check_sample_array(x)
  check_label_array(y)
  check_sample_label_size(x, y)
  label_list = y.to_a
  # Store copies so later mutation of the caller's arrays does not affect the model.
  @prototypes = Numo::DFloat.asarray(x.to_a)
  @labels = Numo::Int32.asarray(label_list)
  @classes = Numo::Int32.asarray(label_list.uniq.sort)
  self
end

#marshal_dumpHash

Dump marshal data.

Returns:

  • (Hash)

    The marshal data about KNeighborsClassifier.



93
94
95
96
97
98
# File 'lib/rumale/nearest_neighbors/k_neighbors_classifier.rb', line 93

# Dump marshal data.
#
# @return [Hash] The hyperparameters and learned state of the classifier.
def marshal_dump
  data = {}
  data[:params] = @params
  data[:prototypes] = @prototypes
  data[:labels] = @labels
  data[:classes] = @classes
  data
end

#marshal_load(obj) ⇒ nil

Load marshal data.

Returns:

  • (nil)


102
103
104
105
106
107
108
# File 'lib/rumale/nearest_neighbors/k_neighbors_classifier.rb', line 102

# Load marshal data.
#
# @param obj [Hash] The data produced by #marshal_dump.
# @return [nil]
def marshal_load(obj)
  @params, @prototypes, @labels, @classes = obj.values_at(:params, :prototypes, :labels, :classes)
  nil
end

#predict(x) ⇒ Numo::Int32

Predict class labels for samples.

Parameters:

  • x (Numo::DFloat)

    (shape: [n_samples, n_features]) The samples to predict the labels.

Returns:

  • (Numo::Int32)

    (shape: [n_samples]) Predicted class label per sample.



84
85
86
87
88
89
# File 'lib/rumale/nearest_neighbors/k_neighbors_classifier.rb', line 84

# Predict class labels for samples.
#
# Each sample gets the class with the highest vote count from #decision_function.
#
# @param x [Numo::DFloat] (shape: [n_samples, n_features]) The samples to predict the labels.
# @return [Numo::Int32] (shape: [n_samples]) Predicted class label per sample.
def predict(x)
  check_sample_array(x)
  scores = decision_function(x)
  predicted = (0...x.shape.first).map { |i| @classes[scores[i, true].max_index] }
  Numo::Int32.asarray(predicted)
end