Class: DNN::Layers::Embedding

Inherits:
TrainableLayer show all
Defined in:
lib/dnn/core/layers/embedding.rb

Instance Attribute Summary collapse

Attributes inherited from TrainableLayer

#trainable

Attributes inherited from Layer

#input_shape

Instance Method Summary collapse

Methods inherited from TrainableLayer

#clean

Methods inherited from Layer

#built?, call, #clean, from_hash, #output_shape

Constructor Details

#initialize(input_dim_or_shape, input_length, weight_initializer: Initializers::RandomUniform.new, weight_regularizer: nil) ⇒ Embedding

Returns a new instance of Embedding.

Parameters:

  • input_dim_or_shape (Integer | Array)

    Set input data dimension or shape.

  • input_length (Integer)

    Set the time series length of input data.

  • weight_initializer (DNN::Initializers::Initializer) (defaults to: Initializers::RandomUniform.new)

    Weight initializer.

  • weight_regularizer (DNN::Regularizers::Regularizer | NilClass) (defaults to: nil)

    Weight regularizer.



14
15
16
17
18
19
20
21
22
23
# File 'lib/dnn/core/layers/embedding.rb', line 14

# Sets up the layer's configuration and a placeholder weight parameter.
#
# @param input_dim_or_shape [Integer, Array] Input data dimension or shape.
#   An Array is stored unchanged; any other value is wrapped in a one-element Array.
# @param input_length [Integer] Time series length of input data (stored as given).
# @param weight_initializer [DNN::Initializers::Initializer] Weight initializer.
# @param weight_regularizer [DNN::Regularizers::Regularizer, nil] Weight regularizer.
def initialize(input_dim_or_shape, input_length,
               weight_initializer: Initializers::RandomUniform.new,
               weight_regularizer: nil)
  super()
  @input_shape =
    if input_dim_or_shape.is_a?(Array)
      input_dim_or_shape
    else
      [input_dim_or_shape]
    end
  @input_length = input_length
  @weight_initializer = weight_initializer
  @weight_regularizer = weight_regularizer
  # Placeholder parameter; presumably the arguments are (data, grad) — confirm against Param.
  @weight = Param.new(nil, Xumo::SFloat[0])
end

Instance Attribute Details

#input_length ⇒ Object (readonly)

Returns the value of attribute input_length.



5
6
7
# File 'lib/dnn/core/layers/embedding.rb', line 5

# Reader for the input length given to the constructor.
#
# @return [Integer] the stored +input_length+ value.
def input_length
  @input_length
end

#weight ⇒ Object (readonly)

Returns the value of attribute weight.



6
7
8
# File 'lib/dnn/core/layers/embedding.rb', line 6

# Reader for the layer's weight parameter (the Param object created in the constructor).
#
# @return [Param] the weight parameter.
def weight
  @weight
end

#weight_initializer ⇒ Object (readonly)

Returns the value of attribute weight_initializer.



7
8
9
# File 'lib/dnn/core/layers/embedding.rb', line 7

# Reader for the weight initializer given to the constructor.
#
# @return [DNN::Initializers::Initializer] the stored weight initializer.
def weight_initializer
  @weight_initializer
end

#weight_regularizer ⇒ Object (readonly)

Returns the value of attribute weight_regularizer.



8
9
10
# File 'lib/dnn/core/layers/embedding.rb', line 8

# Reader for the weight regularizer given to the constructor.
#
# @return [DNN::Regularizers::Regularizer, nil] the stored regularizer, or nil when none was set.
def weight_regularizer
  @weight_regularizer
end

Instance Method Details

#<<(layer) ⇒ Object



71
72
73
74
75
76
# File 'lib/dnn/core/layers/embedding.rb', line 71

# Composes this layer with +layer+ using Proc#<<: the resulting callable
# invokes +layer+ first and feeds its result into this layer's #call.
#
# The interpreter version is compared with Gem::Version instead of as a raw
# String, because String comparison is lexicographic ("2.10.0" < "2.6.0" is true).
#
# @param layer [Proc, Method, #call] Callable to run before this layer.
# @return [Proc] the composed callable.
# @raise [DNN_Error] when running on a Ruby older than 2.6.0.
def <<(layer)
  if Gem::Version.new(RUBY_VERSION) < Gem::Version.new("2.6.0")
    raise DNN_Error, "Function composition is not supported before ruby version 2.6.0."
  end
  to_proc << layer
end

#>>(layer) ⇒ Object



64
65
66
67
68
69
# File 'lib/dnn/core/layers/embedding.rb', line 64

# Composes this layer with +layer+ using Proc#>>: the resulting callable
# invokes this layer's #call first and feeds its result into +layer+.
#
# The interpreter version is compared with Gem::Version instead of as a raw
# String, because String comparison is lexicographic ("2.10.0" < "2.6.0" is true).
#
# @param layer [Proc, Method, #call] Callable to run after this layer.
# @return [Proc] the composed callable.
# @raise [DNN_Error] when running on a Ruby older than 2.6.0.
def >>(layer)
  if Gem::Version.new(RUBY_VERSION) < Gem::Version.new("2.6.0")
    raise DNN_Error, "Function composition is not supported before ruby version 2.6.0."
  end
  to_proc >> layer
end

#backward(dy) ⇒ Object



46
47
48
49
50
51
52
53
54
# File 'lib/dnn/core/layers/embedding.rb', line 46

# Accumulates the upstream gradient into the weight gradient, one input
# element at a time: each value of the stored input selects the gradient
# slot that receives the matching entry of +dy+.
#
# @param dy [Object] Upstream gradient, read as dy[i, j] over the first two
#   axes of the input saved by #forward.
# @return [nil] always nil; no gradient is returned for the input.
def backward(dy)
  # Adding zeros leaves values unchanged; presumably this expands the
  # placeholder gradient to the weight's shape — TODO confirm.
  @weight.grad += Xumo::SFloat.zeros(*@weight.data.shape)
  @x.shape[0].times do |row|
    @x.shape[1].times do |col|
      slot = @x[row, col]
      @weight.grad[slot] += dy[row, col]
    end
  end
  nil
end

#build(input_shape) ⇒ Object



30
31
32
33
34
35
# File 'lib/dnn/core/layers/embedding.rb', line 30

# Marks the layer as built, allocates the weight data and initializes it.
#
# The weight data is allocated with @input_length as its size; the
# +input_shape+ argument is not used by this method.
#
# @param input_shape [Array] Accepted but not read here.
# @return [Object, nil] the weight parameter when a regularizer is attached, otherwise nil.
def build(input_shape)
  @built = true
  @weight.data = Xumo::SFloat.new(@input_length)
  @weight_initializer.init_param(self, @weight)
  return unless @weight_regularizer

  # Let the regularizer know which parameter it applies to.
  @weight_regularizer.param = @weight
end

#call(input_tensor) ⇒ Object



25
26
27
28
# File 'lib/dnn/core/layers/embedding.rb', line 25

# Runs the layer on a tensor, building the layer first if it has not been built.
#
# @param input_tensor [Tensor] Input whose +data+ is passed to #forward.
# @return [Tensor] a new Tensor holding the forward result and a Link
#   created with nil and this layer.
def call(input_tensor)
  build(@input_shape) unless built?
  output = forward(input_tensor.data)
  link = Link.new(nil, self)
  Tensor.new(output, link)
end

#forward(x) ⇒ Object



37
38
39
40
41
42
43
44
# File 'lib/dnn/core/layers/embedding.rb', line 37

# Looks up a weight value for every index in +x+, one leading-axis slice at a time.
#
# The input is kept in @x for use by #backward.
#
# @param x [Object] Index array; each x[i, false] slice is used to index the weight data.
# @return [Object] an array created with the same shape as +x+, filled with the looked-up weights.
def forward(x)
  @x = x
  output = Xumo::SFloat.zeros(*x.shape)
  batch_size = x.shape[0]
  batch_size.times do |row|
    output[row, false] = @weight.data[x[row, false]]
  end
  output
end

#get_params ⇒ Object



89
90
91
# File 'lib/dnn/core/layers/embedding.rb', line 89

# Lists the layer's parameters by name.
#
# @return [Hash] a Hash with the single key +:weight+ mapped to the weight parameter.
def get_params
  { weight: @weight }
end

#load_hash(hash) ⇒ Object



83
84
85
86
87
# File 'lib/dnn/core/layers/embedding.rb', line 83

# Re-initializes this layer from a Hash by calling #initialize again with the
# values read from it.
#
# @param hash [Hash] Must provide +:input_shape+, +:input_length+,
#   +:weight_initializer+ and +:weight_regularizer+; the last two are
#   converted with the respective +from_hash+ class methods.
# @return [Object] whatever #initialize returns.
def load_hash(hash)
  shape = hash[:input_shape]
  length = hash[:input_length]
  initializer = Initializers::Initializer.from_hash(hash[:weight_initializer])
  regularizer = Regularizers::Regularizer.from_hash(hash[:weight_regularizer])
  initialize(shape, length,
             weight_initializer: initializer,
             weight_regularizer: regularizer)
end

#regularizers ⇒ Object



56
57
58
# File 'lib/dnn/core/layers/embedding.rb', line 56

# Lists the regularizers attached to this layer.
#
# @return [Array] a one-element Array holding the weight regularizer, or an
#   empty Array when no regularizer is set.
def regularizers
  return [] unless @weight_regularizer

  [@weight_regularizer]
end

#to_hash ⇒ Object



78
79
80
81
# File 'lib/dnn/core/layers/embedding.rb', line 78

# Serializes the layer's configuration by handing its settings to the parent
# implementation.
#
# The initializer is always converted with +to_hash+; the regularizer entry is
# nil when no regularizer is set.
#
# @return [Object] whatever the parent's +to_hash+ returns for these entries.
def to_hash
  initializer_hash = @weight_initializer.to_hash
  regularizer_hash = @weight_regularizer&.to_hash
  super(input_shape: @input_shape,
        input_length: @input_length,
        weight_initializer: initializer_hash,
        weight_regularizer: regularizer_hash)
end

#to_proc ⇒ Object



60
61
62
# File 'lib/dnn/core/layers/embedding.rb', line 60

# Converts the layer into a Proc that forwards to #call.
#
# @return [Proc] a Proc obtained from the bound #call method.
def to_proc
  method(:call).to_proc
end