Class: ML::Learner::PerceptronLearner

Inherits:
  • Object
Defined in:
lib/method/perceptron.rb

Overview

Implementation of the Perceptron Learning Algorithm (PLA) for binary classification.
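For reference, the textbook PLA classifies an augmented input (the features followed by a constant 1 for the bias, the layout #predict uses) by the sign of a dot product. For each misclassified sample, it adds that sample, scaled by its label, to the weights. The source of update_vector is not shown here, so it is an assumption that it applies exactly this rule:

```latex
\tilde{\mathbf{x}} = \begin{pmatrix} \mathbf{x} \\ 1 \end{pmatrix}, \qquad
h(\mathbf{x}) = \operatorname{sign}\!\left(\mathbf{w}^{\top}\tilde{\mathbf{x}}\right), \qquad
\mathbf{w} \leftarrow \mathbf{w} + y\,\tilde{\mathbf{x}} \quad \text{if } h(\mathbf{x}) \neq y,\ y \in \{-1, +1\}
```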

Direct Known Subclasses

AdaptivePerceptronLearner

Instance Method Summary

  • #line ⇒ Array
  • #predict(data) ⇒ Integer
  • #train!(data, threshold = 1.0/0) ⇒ Array

Constructor Details

#initialize(dim, thres = 1.0/0) ⇒ PerceptronLearner

Initializes a perceptron learner with all weights set to zero.

Parameters:

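  • dim (Integer)

    the number of input dimensions (features); the weight vector holds dim + 1 entries, the last one being the bias

  • thres (Numeric) (defaults to: 1.0/0)

    accepted but not used by this constructor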



# File 'lib/method/perceptron.rb', line 10

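def initialize(dim, thres = 1.0/0)
  @dim = dim
  # weights w1..w_dim followed by the bias term, all starting at zero
  # (thres is accepted but not used here)
  @w = Matrix.column_vector(Array.new(dim + 1, 0))
end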

Instance Method Details

#line ⇒ Array

Returns the coefficients of the separating line learned by the perceptron, i.e. the current weight vector as an Array.

Returns:

  • (Array)

    the weights [w1, ..., w_dim, bias]; for a 2-D learner this is [a, b, c] for the line ax + by + c = 0



# File 'lib/method/perceptron.rb', line 60

def line
  @w.column(0).to_a
end
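For a 2-D learner, the [a, b, c] returned by #line can be converted to slope-intercept form, e.g. for plotting the decision boundary. This is a minimal sketch; slope_intercept is a hypothetical helper, not part of the library:

```ruby
# Hypothetical helper (not part of ML::Learner): convert the [a, b, c]
# returned by #line for a 2-D learner into [slope, intercept] of
# y = slope * x + intercept. Returns nil for a vertical line (b == 0).
def slope_intercept(coeffs)
  a, b, c = coeffs
  return nil if b.zero?
  # a*x + b*y + c = 0  =>  y = -(a/b) * x - (c/b)
  [-(a.fdiv(b)), -(c.fdiv(b))]
end

p slope_intercept([1.0, -1.0, 0.0])   # => [1.0, 0.0]   (the line y = x)
p slope_intercept([2.0, 4.0, -8.0])   # => [-0.5, 2.0]  (y = -0.5x + 2)
p slope_intercept([1.0, 0.0, -3.0])   # => nil          (vertical line x = 3)
```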

#predict(data) ⇒ Integer

Predicts the label of a single data point.

Parameters:

  • data (Array)

    the feature vector to classify, without the bias term (a trailing 1.0 is appended internally)

Returns:

  • (Integer)

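    the predicted label: 1 or -1 depending on which side of the line the point falls, or 0 if it lies exactly on the line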



# File 'lib/method/perceptron.rb', line 68

def predict(data)
  # append the bias input 1.0, then take the sign of the score (-1, 0 or 1)
  classify(Matrix.column_vector(data + [1.0])) <=> 0
end
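The decision rule in #predict can be sketched with plain Arrays instead of Matrix. This assumes classify computes the dot product of the weights with the augmented input (its source is not shown here); predict_sketch is a hypothetical name. The last call shows that a freshly constructed learner, whose weights are all zero, puts every point on the boundary:

```ruby
# Hypothetical stand-alone version of #predict using plain Arrays.
# Assumes classify(x) is the dot product w . x and that the bias weight
# is stored last, matching the data + [1.0] augmentation in #predict.
def predict_sketch(w, data)
  aug = data + [1.0]                            # append the constant bias input
  raise ArgumentError, "size mismatch" unless aug.size == w.size
  score = w.zip(aug).sum { |wi, xi| wi * xi }   # w . x
  score <=> 0                                   # 1, -1, or 0 (on the line)
end

w = [1.0, -1.0, 0.0]                            # the line x - y = 0
p predict_sketch(w, [2.0, 1.0])                 # => 1
p predict_sketch(w, [1.0, 3.0])                 # => -1
p predict_sketch(w, [1.0, 1.0])                 # => 0
p predict_sketch([0.0, 0.0, 0.0], [5.0, 7.0])   # => 0 (untrained, all-zero weights)
```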

#train!(data, threshold = 1.0/0) ⇒ Array

Trains the perceptron on labelled data, updating the weights in place.

Parameters:

  • data (Hash)

    labelled training data: a Hash mapping each feature vector (Array) to its label (Integer, 1 or -1)

  • threshold (Numeric) (defaults to: 1.0/0)

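    the maximum number of weight updates; training stops earlier if every sample is classified correctly. With the default (infinity), training terminates only if the data are linearly separable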

Returns:

  • (Array)

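    [error, update]: error is the number of misclassified training samples (computed only when training stopped at the threshold, otherwise 0) and update is the number of weight updates performed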



# File 'lib/method/perceptron.rb', line 20

def train!(data, threshold = 1.0/0)
  pool = data.to_a
  update = 0
  error = 0

  loop do
    break if update >= threshold
    misclassified = false
    # visit every sample, index 0 included, in random order
    order = (0...pool.size).to_a.shuffle

    order.each do |i|
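      dat, result = pool[i]
      # append the bias input 1.0, matching #predict
      aug_data = Matrix.column_vector(dat + [1.0])

      if wrongly_classify aug_data, result
        misclassified = true

        # correct the weights with the first misclassified sample, then rescan
        update_vector aug_data, result
        update += 1
        break
      end
    end

    break unless misclassified
  end

  # count errors only if training stopped at the update limit;
  # otherwise the last scan found no misclassified sample
  if update >= threshold
    pool.each do |dat, result|
      classified_result = (classify(Matrix.column_vector(dat + [1.0])) <=> 0)
      error += 1 unless result == classified_result
    end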
  end

  [error, update]
end
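The training loop can be sketched end to end with plain Arrays. This assumes classify is a dot product and update_vector applies the standard PLA rule w <- w + y * x; neither source is shown here, and train_sketch is a hypothetical name. Unlike #train!, it also returns the learned weights, and it takes a seeded Random so runs are reproducible:

```ruby
# Hypothetical plain-Array sketch of #train! (not the library's code).
# Assumptions: the bias input 1.0 is appended last, classify is a dot
# product, and update_vector applies the standard PLA rule w <- w + y * x.
def train_sketch(data, threshold = Float::INFINITY, rng = Random.new(0))
  pool = data.map { |x, y| [x + [1.0], y] }       # augment each sample with the bias input
  w = Array.new(pool.first[0].size, 0.0)          # zero weights, as in #initialize
  score = ->(x) { w.zip(x).sum { |wi, xi| wi * xi } }
  wrong = ->(x, y) { (score.(x) <=> 0) != y }     # a score of 0 counts as a mistake
  update = 0

  loop do
    break if update >= threshold
    # scan all samples in random order and take the first misclassified one
    mistake = pool.shuffle(random: rng).find { |x, y| wrong.(x, y) }
    break unless mistake                          # converged: nothing misclassified
    x, y = mistake
    w = w.zip(x).map { |wi, xi| wi + y * xi }     # w <- w + y * x
    update += 1
  end

  # as in #train!, errors are only counted when the update budget ran out
  error = update >= threshold ? pool.count { |x, y| wrong.(x, y) } : 0
  [w, error, update]
end

separable = { [2.0, 1.0] => 1, [3.0, 0.5] => 1, [0.0, 2.0] => -1, [-1.0, 1.0] => -1 }
w, error, update = train_sketch(separable)
p [error, update]                                 # error is 0 once converged

xor = { [0.0, 0.0] => -1, [1.0, 1.0] => -1, [0.0, 1.0] => 1, [1.0, 0.0] => 1 }
p train_sketch(xor, 20).drop(1)                   # not separable: stops at 20 updates
```

The XOR labelling is not linearly separable, so every scan finds a mistake and only the finite threshold stops the loop; this is why the default infinite threshold is safe only for separable data.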