Class: DNN::Layers::Embedding

Inherits:
HasParamLayer show all
Defined in:
lib/dnn/core/embedding.rb

Instance Attribute Summary collapse

Attributes inherited from HasParamLayer

#trainable

Attributes inherited from Layer

#input_shape, #name

Instance Method Summary collapse

Methods inherited from Layer

#built?, call, from_hash, #output_shape

Constructor Details

#initialize(input_dim_or_shape, input_length, weight_initializer: Initializers::RandomUniform.new, weight_regularizer: nil) ⇒ Embedding

Returns a new instance of Embedding.

Parameters:

  • input_dim_or_shape (Integer | Array)

    Set input data dimension or shape.

  • input_length (Integer)

    Set the time series length of input data.

  • weight_initializer (DNN::Initializers::Initializer) (defaults to: Initializers::RandomUniform.new)

    Weight initializer.

  • weight_regularizer (DNN::Regularizers::Regularizer | NilClass) (defaults to: nil)

    Weight regularizer.



14
15
16
17
18
19
20
21
22
# File 'lib/dnn/core/embedding.rb', line 14

# Builds an Embedding layer from its configuration; no weights are allocated here.
#
# @param input_dim_or_shape [Integer, Array] input dimension, or a full shape.
#   A non-Array value is wrapped into a one-element Array.
# @param input_length [Integer] time-series length of the input data.
# @param weight_initializer [DNN::Initializers::Initializer] weight initializer.
# @param weight_regularizer [DNN::Regularizers::Regularizer, nil] optional weight regularizer.
def initialize(input_dim_or_shape, input_length,
               weight_initializer: Initializers::RandomUniform.new,
               weight_regularizer: nil)
  super()
  @input_shape =
    case input_dim_or_shape
    when Array then input_dim_or_shape
    else [input_dim_or_shape]
    end
  @input_length = input_length
  @weight_initializer = weight_initializer
  @weight_regularizer = weight_regularizer
end

Instance Attribute Details

#input_length ⇒ Object (readonly)

Returns the value of attribute input_length.



5
6
7
# File 'lib/dnn/core/embedding.rb', line 5

def input_length
  @input_length
end

#weight ⇒ Object (readonly)

Returns the value of attribute weight.



6
7
8
# File 'lib/dnn/core/embedding.rb', line 6

# Reader for the embedding weight parameter (assigned in #build).
#
# @return [Object] the value held in @weight.
def weight; @weight; end

#weight_initializer ⇒ Object (readonly)

Returns the value of attribute weight_initializer.



7
8
9
# File 'lib/dnn/core/embedding.rb', line 7

# Reader for the weight initializer given to the constructor.
#
# @return [Object] the value held in @weight_initializer.
def weight_initializer; @weight_initializer; end

#weight_regularizer ⇒ Object (readonly)

Returns the value of attribute weight_regularizer.



8
9
10
# File 'lib/dnn/core/embedding.rb', line 8

# Reader for the optional weight regularizer given to the constructor.
#
# @return [Object, nil] the value held in @weight_regularizer.
def weight_regularizer; @weight_regularizer; end

Instance Method Details

#<<(layer) ⇒ Object



70
71
72
73
74
75
# File 'lib/dnn/core/embedding.rb', line 70

# Composes this layer with another callable using Proc#<< semantics:
# +layer+ runs first and its result is passed to this layer,
# i.e. (self << layer).call(x) == self.call(layer.call(x)).
#
# @param layer [#call] a layer or any object responding to +call+.
# @return [Proc] the composed function.
# @raise [DNN_Error] when Proc#<< is unavailable (Ruby older than 2.6).
def <<(layer)
  # Detect the feature directly rather than comparing RUBY_VERSION strings:
  # String#< is lexicographic, so e.g. "10.0.0" < "2.6.0" would be true.
  unless Proc.method_defined?(:<<)
    raise DNN_Error, "Function composition is not supported before ruby version 2.6.0."
  end
  to_proc << layer
end

#>>(layer) ⇒ Object



63
64
65
66
67
68
# File 'lib/dnn/core/embedding.rb', line 63

# Composes this layer with another callable using Proc#>> semantics:
# this layer runs first and its result is passed to +layer+,
# i.e. (self >> layer).call(x) == layer.call(self.call(x)).
#
# @param layer [#call] a layer or any object responding to +call+.
# @return [Proc] the composed function.
# @raise [DNN_Error] when Proc#>> is unavailable (Ruby older than 2.6).
def >>(layer)
  # Detect the feature directly rather than comparing RUBY_VERSION strings:
  # String#< is lexicographic, so e.g. "10.0.0" < "2.6.0" would be true.
  unless Proc.method_defined?(:>>)
    raise DNN_Error, "Function composition is not supported before ruby version 2.6.0."
  end
  to_proc >> layer
end

#backward(dy) ⇒ Object



45
46
47
48
49
50
51
52
53
# File 'lib/dnn/core/embedding.rb', line 45

# Accumulates the gradient of the embedding weight from +dy+.
#
# For every position (row, col) of the input cached by #forward, dy[row, col]
# is added to the weight-gradient entry selected by that input value.
# Accumulation is done one element at a time so that repeated indices add up.
#
# @param dy upstream gradient, indexed with the same (row, col) as the cached input.
# @return [nil] nothing is propagated to earlier layers.
def backward(dy)
  # Adding zeros shaped like the weight data widens the gradient to the full
  # weight shape. In #build it starts as Xumo::SFloat[0], presumably Param's
  # gradient slot; TODO confirm.
  @weight.grad += Xumo::SFloat.zeros(*@weight.data.shape)
  n_rows = @x.shape[0]
  (0...n_rows).each do |row|
    n_cols = @x.shape[1]
    (0...n_cols).each do |col|
      index = @x[row, col]
      @weight.grad[index] += dy[row, col]
    end
  end
  nil
end

#build ⇒ Object



29
30
31
32
33
34
# File 'lib/dnn/core/embedding.rb', line 29

# Marks the layer as built, then allocates and initializes the weight parameter.
# If a regularizer was supplied, it is attached to the new weight.
#
# NOTE(review): the weight table has @input_length entries, while #forward
# indexes it with the *values* of the input. Every input value must therefore
# lie in 0...@input_length. Confirm that sizing by input_length rather than
# by the input dimension is intended.
def build
  @built = true
  table = Xumo::SFloat.new(@input_length)
  # Second Param argument; presumably the initial gradient (see #backward). TODO confirm.
  grad_placeholder = Xumo::SFloat[0]
  @weight = Param.new(table, grad_placeholder)
  @weight_initializer.init_param(self, @weight)
  return unless @weight_regularizer
  @weight_regularizer.param = @weight
end

#call(input_tensor) ⇒ Object



24
25
26
27
# File 'lib/dnn/core/embedding.rb', line 24

# Runs the layer on a tensor, building it lazily on first use.
#
# @param input_tensor [Tensor] input whose +data+ is fed to #forward.
# @return [Tensor] the forward result, linked back to this layer.
def call(input_tensor)
  build unless built?
  y = forward(input_tensor.data)
  link = Link.new(nil, self)
  Tensor.new(y, link)
end

#forward(x) ⇒ Object



36
37
38
39
40
41
42
43
# File 'lib/dnn/core/embedding.rb', line 36

# Replaces every entry of +x+ with the weight value it indexes, one row at a time.
#
# The input is cached in @x for #backward.
#
# @param x index array; its values are used as indices into the weight data.
# @return [Xumo::SFloat] an array with the same shape as +x+.
def forward(x)
  @x = x
  out = Xumo::SFloat.zeros(*x.shape)
  (0...x.shape[0]).each do |row|
    out[row, false] = @weight.data[x[row, false]]
  end
  out
end

#get_params ⇒ Object



88
89
90
# File 'lib/dnn/core/embedding.rb', line 88

# Exposes this layer's trainable parameters by name.
#
# @return [Hash{Symbol => Object}] a new hash mapping :weight to the weight parameter.
def get_params
  params = {}
  params[:weight] = @weight
  params
end

#load_hash(hash) ⇒ Object



82
83
84
85
86
# File 'lib/dnn/core/embedding.rb', line 82

# Re-initializes this layer from a hash of the form produced by #to_hash.
#
# @param hash [Hash] reads :input_shape, :input_length, :weight_initializer
#   and :weight_regularizer. The last two are rebuilt through their classes' from_hash.
def load_hash(hash)
  shape = hash[:input_shape]
  length = hash[:input_length]
  initializer = Initializers::Initializer.from_hash(hash[:weight_initializer])
  regularizer = Regularizers::Regularizer.from_hash(hash[:weight_regularizer])
  initialize(shape, length,
             weight_initializer: initializer,
             weight_regularizer: regularizer)
end

#regularizers ⇒ Object



55
56
57
# File 'lib/dnn/core/embedding.rb', line 55

# Lists the regularizers that apply to this layer.
#
# @return [Array] a one-element array with the weight regularizer, or [] when none is set.
def regularizers
  return [] unless @weight_regularizer
  [@weight_regularizer]
end

#to_hash ⇒ Object



77
78
79
80
# File 'lib/dnn/core/embedding.rb', line 77

# Serializes this layer's configuration by passing its own settings to the parent's to_hash.
#
# The initializer is always serialized. The regularizer is serialized only
# when present and is nil otherwise.
def to_hash
  initializer_hash = @weight_initializer.to_hash
  regularizer_hash = @weight_regularizer&.to_hash
  super(input_shape: @input_shape, input_length: @input_length,
        weight_initializer: initializer_hash, weight_regularizer: regularizer_hash)
end

#to_proc ⇒ Object



59
60
61
# File 'lib/dnn/core/embedding.rb', line 59

# Converts this layer into a Proc that forwards to #call.
#
# @return [Proc] a lambda built from the bound +call+ method.
def to_proc
  bound_call = method(:call)
  bound_call.to_proc
end