Class: MachineLearningWorkbench::Compressor::VectorQuantization

Inherits:
Object
  • Object
show all
Defined in:
lib/machine_learning_workbench/compressor/vector_quantization.rb

Overview

Standard Vector Quantization

Direct Known Subclasses

OnlineVectorQuantization

Constant Summary collapse

# Local alias so this class can refer to the project's tools module as
# `Verification`. NOTE(review): not referenced by any method shown here.
Verification =
MachineLearningWorkbench::Tools::Verification
# Similarity measures between a centroid and an input vector, keyed by name.
# In both measures a larger value means "more similar":
#   :dot -- dot product of centroid and vector
#   :mse -- negated mean squared error (0 for a perfect match, negative otherwise)
SIMIL = {
  dot: lambda { |centr, vec| centr.dot(vec) },
  mse: lambda { |centr, vec|
    sq_err_sum = ((centr - vec)**2).sum
    -sq_err_sum / centr.size
  }
}

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(ncentrs:, dims:, vrange:, lrate:, simil_type: nil, init_centr_vrange: nil, rseed: Random.new_seed) ⇒ VectorQuantization

Returns a new instance of VectorQuantization.



8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 8

# Build a vector quantizer with `ncentrs` randomly-initialized centroids.
# @param ncentrs [Integer] number of centroids
# @param dims [Integer, Array<Integer>] shape of each centroid (coerced with `Array()`)
# @param vrange [Array(Numeric, Numeric), Range] value range, stored as `[low, high]` Floats
# @param lrate [Numeric] learning rate, validated by `check_lrate`
# @param simil_type [Symbol, nil] key into `SIMIL`; defaults to `:dot`
# @param init_centr_vrange [Array(Numeric, Numeric), Range, nil] value range for the
#   random initial centroids; defaults to `vrange`
# @param rseed [Integer] seed for `rng`
# @raise [ArgumentError] if `vrange`/`init_centr_vrange` is malformed or `check_lrate` rejects `lrate`
def initialize ncentrs:, dims:, vrange:, lrate:, simil_type: nil, init_centr_vrange: nil, rseed: Random.new_seed
  # TODO: RNG CURRENTLY NOT USED!!
  @rng = Random.new rseed
  @ncentrs = ncentrs
  @dims = Array(dims)
  check_lrate lrate # hack: so that we can overload it in online_vq
  @lrate = lrate
  @simil_type = simil_type || :dot
  @vrange = normalize_vrange vrange
  # Honor an explicit `init_centr_vrange` (previously ignored: `||=` on a still-nil
  # ivar always picked `vrange`). Normalizing to `[low, high]` Floats also lets
  # `new_centr` splat it safely when a Range was given.
  @init_centr_vrange = init_centr_vrange ? normalize_vrange(init_centr_vrange, 'init_centr_vrange') : @vrange
  init_centrs
  @ntrains = [0]*ncentrs # useful to understand what happens
end

# Normalize a value range given as a 2-element Array or a Range into `[low, high]` Floats.
# @param range [Array, Range] the range to normalize
# @param name [String] argument name used in error messages
# @return [Array(Float, Float)]
# @raise [ArgumentError] on a wrong-size Array or an unsupported type
private def normalize_vrange(range, name = 'vrange')
  case range
  when Array
    raise ArgumentError, "#{name} size not 2: #{range}" unless range.size == 2
    range.map(&method(:Float))
  when Range
    [range.first, range.last].map(&method(:Float))
  else raise ArgumentError, "#{name}: unrecognized type: #{range.class}"
  end
end

Instance Attribute Details

#centrs ⇒ Object (readonly)

Returns the value of attribute centrs.



5
6
7
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 5

# @return [Array] the current list of centroids (readonly)
def centrs; @centrs; end

#dims ⇒ Object (readonly)

Returns the value of attribute dims.



5
6
7
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 5

# @return [Array] the shape of each centroid (readonly)
def dims; @dims; end

#init_centr_vrange ⇒ Object (readonly)

Returns the value of attribute init_centr_vrange.



5
6
7
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 5

# @return [Object] the value range used to draw initial centroids (readonly)
def init_centr_vrange; @init_centr_vrange; end

#lrate ⇒ Object (readonly)

Returns the value of attribute lrate.



5
6
7
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 5

# @return [Numeric] the learning rate used by `train_one` (readonly)
def lrate; @lrate; end

#ncentrs ⇒ Object (readonly)

Returns the value of attribute ncentrs.



5
6
7
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 5

# @return [Integer] the number of centroids (readonly)
def ncentrs; @ncentrs; end

#ntrains ⇒ Object (readonly)

Returns the value of attribute ntrains.



5
6
7
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 5

# @return [Array<Integer>] how many times each centroid has been trained (readonly)
def ntrains; @ntrains; end

#rng ⇒ Object (readonly)

Returns the value of attribute rng.



5
6
7
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 5

# @return [Random] the seeded random generator (readonly; see the TODO in the constructor)
def rng; @rng; end

#simil_type ⇒ Object (readonly)

Returns the value of attribute simil_type.



5
6
7
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 5

# @return [Symbol] the default similarity measure, a key of `SIMIL` (readonly)
def simil_type; @simil_type; end

#vrange ⇒ Object (readonly)

Returns the value of attribute vrange.



5
6
7
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 5

# @return [Array(Float, Float)] the normalized `[low, high]` value range (readonly)
def vrange; @vrange; end

Instance Method Details

#check_lrate(lrate) ⇒ Object

Verify that `lrate` is present and within the unit interval [0, 1]. This is a separate method only so that it can be overridden in `OnlineVectorQuantization`.

Raises:

  • (ArgumentError)


31
32
33
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 31

# Verify that `lrate` is present and within [0, 1].
# Kept as a separate method only so `OnlineVectorQuantization` can override it.
# @raise [ArgumentError] if `lrate` is nil or outside the unit interval
def check_lrate lrate
  return if lrate&.between?(0, 1)
  raise ArgumentError, "Pass a `lrate` between 0 and 1"
end

#encode(vec, type: :most_similar) ⇒ Object

Encode a vector. TODO: optimize for Numo.



67
68
69
70
71
72
73
74
75
76
77
78
79
80
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 67

# Encode a vector through its similarities to the centroids.
# TODO: optimize for Numo
# @param vec [Object] vector passed to `similarities`
# @param type [Symbol] one of:
#   :most_similar  -- index of the first centroid with maximal similarity
#   :ensemble      -- the raw similarities to every centroid
#   :ensemble_norm -- the similarities divided by their total
# @return [Integer, Array] the code
# @raise [ArgumentError] on an unrecognized type
def encode vec, type: :most_similar
  sims = similarities(vec)
  case type
  when :most_similar then sims.index(sims.max)
  when :ensemble then sims
  when :ensemble_norm
    total = sims.inject(:+)
    total = 1 if total == 0 # HACK: avoid division by zero
    sims.map { |sim| sim / total }
  else
    raise ArgumentError, "unrecognized encode type: #{type}"
  end
end

#init_centrs(nc: ncentrs, base: nil, proport: nil) ⇒ Object

Initializes a list of centroids



36
37
38
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 36

# Initializes the list of centroids, replacing any existing ones.
# @param nc [Integer] how many centroids to create (defaults to `ncentrs`)
# @param base [Object, nil] forwarded to `new_centr` (give both `base` and `proport`, or neither)
# @param proport [Numeric, nil] forwarded to `new_centr`
# @return [Array] the new centroids, which are also stored for `centrs`
def init_centrs nc: ncentrs, base: nil, proport: nil
  fresh = nc.times.map { new_centr(base, proport) }
  @centrs = fresh
end

#most_similar_centr(vec) ⇒ Array<Integer, Float>

Returns the index and similarity of the centroid most similar to the vector.

Returns:

  • (Array<Integer, Float>)

    the index of the most similar centroid, followed by the corresponding similarity



99
100
101
102
103
104
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 99

# Find the centroid most similar to a vector.
# @param vec [Object] vector passed to `similarities`
# @return [Array<Integer, Float>] the index of the first most similar centroid,
#   followed by its similarity
def most_similar_centr vec
  simils = similarities(vec)
  top_simil = simils.max
  top_idx = simils.index(top_simil)
  return top_idx, top_simil
end

#new_centr(base = nil, proport = nil) ⇒ Object

Creates a new (random) centroid. If a base is passed, it is blended with the random centroid; this is done to facilitate distributing the training across centroids. TODO: USE RNG HERE!!

Raises:

  • (ArgumentError)


44
45
46
47
48
49
50
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 44

# Creates a new (random) centroid of shape `dims`, drawn within `init_centr_vrange`.
# If a base is passed, it is blended with the random centroid
# (`random * (1-proport) + base * proport`). This is done to facilitate
# distributing the training across centroids.
# TODO: USE RNG HERE!!
# @param base [Object, nil] centroid to blend with; requires `proport`
# @param proport [Numeric, nil] weight of `base` in the blend; requires `base`
# @return [NArray] the new centroid
# @raise [ArgumentError] if exactly one of `base` and `proport` is given
def new_centr base=nil, proport=nil
  raise ArgumentError, "Either both or none" if base.nil? ^ proport.nil?
  vr = init_centr_vrange
  # The constructor accepts Ranges for the value range; splatting a Range would
  # expand its elements (or fail for Floats) instead of giving `low, high`.
  bounds = vr.is_a?(Range) ? [vr.first, vr.last] : vr
  ret = NArray.new(*dims).rand(*bounds)
  ret = ret * (1-proport) + base * proport if base && proport
  ret
end

#reconstr_error(vec, code: nil, type: :most_similar) ⇒ NArray

Total reconstruction error of the vector: the sum of the absolute per-element differences between the vector and its reconstruction.

Returns:



108
109
110
111
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 108

# Total reconstruction error of a vector: the sum of the absolute element-wise
# differences between the vector and its reconstruction.
# @param vec [Object] the original vector
# @param code [Object, nil] precomputed code; computed with `encode` when absent
# @param type [Symbol] encoding type, forwarded to `encode` and `reconstruction`
# @return [Object] the result of `.abs.sum` on the residual
def reconstr_error vec, code: nil, type: :most_similar
  code = encode(vec, type: type) unless code
  residual = vec - reconstruction(code, type: type)
  residual.abs.sum
end

#reconstruction(code, type: :most_similar) ⇒ Object

Reconstruct vector from its code (encoding)



83
84
85
86
87
88
89
90
91
92
93
94
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 83

# Reconstruct a vector from its code (encoding).
# @param code [Integer, Array<Numeric>] a centroid index for `:most_similar`, or one
#   contribution per centroid for `:ensemble` / `:ensemble_norm`
# @param type [Symbol] the encoding type that produced `code`
# @return [Object] the reconstructed (centroid-shaped) vector
# @raise [ArgumentError] on an unrecognized type
def reconstruction code, type: :most_similar
  case type
  when :most_similar
    centrs[code]
  when :ensemble
    tot = code.reduce :+
    # Same guard as `encode(type: :ensemble_norm)`. Without it a zero total divides
    # by zero, and `:ensemble` disagrees with the `:ensemble_norm` path.
    tot = 1 if tot == 0
    centrs.zip(code).map { |centr, contr| centr*contr/tot }.reduce :+
  when :ensemble_norm
    centrs.zip(code).map { |centr, contr| centr*contr }.reduce :+
  else raise ArgumentError, "unrecognized reconstruction type: #{type}"
  end
end

#similarities(vec, type: simil_type) ⇒ Object

Computes similarities between vector and all centroids

Raises:

  • (NotImplementedError)


58
59
60
61
62
63
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 58

# Computes the similarity between a vector and every centroid.
# @param vec [Object] a 1-D vector (inputs whose `shape` has more than one entry are rejected)
# @param type [Symbol] similarity measure, a key of `SIMIL` (defaults to `simil_type`)
# @return [Array] one similarity per centroid, in centroid order
# @raise [NotImplementedError] if `vec` has more than one dimension
# @raise [ArgumentError] if `type` is not a key of `SIMIL`
def similarities vec, type: simil_type
  raise NotImplementedError if vec.shape.size > 1
  # Look the measure up once rather than per centroid. An unknown type now fails
  # clearly instead of with `NoMethodError` on `nil.call`.
  simil = SIMIL.fetch(type) { raise ArgumentError, "unrecognized similarity type: #{type}" }
  centrs.map { |centr| simil.call centr, vec }
  # require 'parallel'
  # Parallel.map(centrs) { |c| c.dot(vec).first }
end

#train(vec_lst, debug: false) ⇒ Object

Train on vector list



124
125
126
127
128
129
130
131
132
133
134
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 124

# Train on a list of vectors, one at a time and in order. Each vector updates its
# most similar centroid via `train_one`, and that centroid's counter in `ntrains`
# is incremented.
# @param vec_lst [Enumerable] vectors to train on
# @param debug [Boolean] when true, print a '.' per trained vector
# @return [Object] whatever `vec_lst.each` returns (the list itself for Arrays)
def train vec_lst, debug: false
  # Two ways here:
  # - Batch: canonical, centrs updated with each vec
  # - Parallel: could be parallel either on simils or on training (?)
  # Unsure on the correctness of either Parallel, let's stick with Batch
  vec_lst.each do |vec|
    trained_idx = train_one vec
    print '.' if debug
    ntrains[trained_idx] += 1
  end
end

#train_one(vec) ⇒ Integer

Train on one vector

Returns:

  • (Integer)

    index of trained centroid



115
116
117
118
119
120
121
# File 'lib/machine_learning_workbench/compressor/vector_quantization.rb', line 115

# Train on one vector: pull the most similar centroid towards `vec`.
# The update is `c * (1 - lrate) + vec * lrate`, i.e. the centroid moves a fraction
# `lrate` of the way towards `vec` (it could also be written as the dot product
# `[c, vec].dot([1-lrate, lrate])`).
# @param vec [Object] the training vector
# @return [Integer] index of the trained centroid
def train_one vec
  idx = most_similar_centr(vec).first
  keep = 1 - lrate
  centrs[idx] = centrs[idx] * keep + vec * lrate
  idx
end