Class: DNN::Layers::Embedding
Instance Attribute Summary collapse
#trainable
Attributes inherited from Layer
#input_shape
Instance Method Summary collapse
#clean
Methods inherited from Layer
#built?, call, #clean, from_hash, #output_shape
Constructor Details
#initialize(input_dim_or_shape, input_length, weight_initializer: Initializers::RandomUniform.new, weight_regularizer: nil) ⇒ Embedding
Returns a new instance of Embedding.
14
15
16
17
18
19
20
21
22
23
|
# File 'lib/dnn/core/layers/embedding.rb', line 14
# Creates an embedding layer.
#
# @param input_dim_or_shape [Integer, Array] input dimension or full input shape;
#   any non-Array value is wrapped into a one-element Array.
# @param input_length [Integer] stored as-is; #build uses it as the length of the weight vector.
# @param weight_initializer object whose #init_param fills the weight in #build
#   (defaults to a fresh Initializers::RandomUniform).
# @param weight_regularizer optional regularizer; nil means no regularization.
def initialize(input_dim_or_shape, input_length,
               weight_initializer: Initializers::RandomUniform.new,
               weight_regularizer: nil)
  super()
  @input_shape =
    case input_dim_or_shape
    when Array then input_dim_or_shape
    else [input_dim_or_shape]
    end
  @input_length = input_length
  @weight_initializer = weight_initializer
  @weight_regularizer = weight_regularizer
  # Weight data starts as nil and is allocated by #build.
  # NOTE(review): arguments presumably are (data, grad) — confirm against Param.
  @weight = Param.new(nil, Xumo::SFloat[0])
end
|
Instance Attribute Details
#input_length ⇒ Object
Returns the value of attribute input_length.
5
6
7
|
# File 'lib/dnn/core/layers/embedding.rb', line 5
# @return the +input_length+ given to the constructor (the weight length used by #build).
def input_length; @input_length; end
|
#weight ⇒ Object
Returns the value of attribute weight.
6
7
8
|
# File 'lib/dnn/core/layers/embedding.rb', line 6
# @return the weight parameter object held by this layer.
def weight; @weight; end
|
#weight_initializer ⇒ Object
Returns the value of attribute weight_initializer.
7
8
9
|
# File 'lib/dnn/core/layers/embedding.rb', line 7
# @return the initializer object used to fill the weight in #build.
def weight_initializer; @weight_initializer; end
|
#weight_regularizer ⇒ Object
Returns the value of attribute weight_regularizer.
8
9
10
|
# File 'lib/dnn/core/layers/embedding.rb', line 8
# @return the weight regularizer, or nil when none was configured.
def weight_regularizer; @weight_regularizer; end
|
Instance Method Details
#<<(layer) ⇒ Object
71
72
73
74
75
76
|
# File 'lib/dnn/core/layers/embedding.rb', line 71
# Composes +layer+ in front of this layer using Proc#<<:
# the result calls +layer+ first and feeds its output to this layer.
#
# @param layer [#call] callable applied before this layer.
# @return [Proc] the composed callable.
# @raise [DNN_Error] on Ruby versions before 2.6.0, which lack Proc composition.
def <<(layer)
  raise DNN_Error, "Function composition is not supported before ruby version 2.6.0." if RUBY_VERSION < "2.6.0"

  to_proc << layer
end
|
#>>(layer) ⇒ Object
64
65
66
67
68
69
|
# File 'lib/dnn/core/layers/embedding.rb', line 64
# Composes +layer+ after this layer using Proc#>>:
# the result calls this layer first and feeds its output to +layer+.
#
# @param layer [#call] callable applied after this layer.
# @return [Proc] the composed callable.
# @raise [DNN_Error] on Ruby versions before 2.6.0, which lack Proc composition.
def >>(layer)
  raise DNN_Error, "Function composition is not supported before ruby version 2.6.0." if RUBY_VERSION < "2.6.0"

  to_proc >> layer
end
|
#backward(dy) ⇒ Object
46
47
48
49
50
51
52
53
54
|
# File 'lib/dnn/core/layers/embedding.rb', line 46
# Accumulates the weight gradient for the input cached by #forward.
#
# Every entry of +@x+ is used as an index into the weight gradient, and the
# matching entry of +dy+ is added at that position.
#
# NOTE(review): only the first two axes of +@x+ are walked, so this assumes
# +@x+ and +dy+ are 2-D (e.g. batch x sequence) — TODO confirm with callers.
#
# @param dy upstream gradient, indexed like +@x+.
# @return [nil] nothing is propagated to the input.
def backward(dy)
  # Adding zeros shaped like the weight data gives grad that shape before the
  # indexed accumulation below.
  @weight.grad += Xumo::SFloat.zeros(*@weight.data.shape)
  rows, cols = @x.shape
  rows.times do |row|
    cols.times do |col|
      token = @x[row, col]
      @weight.grad[token] += dy[row, col]
    end
  end
  nil
end
|
#build(input_shape) ⇒ Object
30
31
32
33
34
35
|
# File 'lib/dnn/core/layers/embedding.rb', line 30
# Allocates the weight vector, initializes it, and wires up the regularizer.
#
# The weight length comes from +@input_length+. The +input_shape+ argument is
# accepted for interface compatibility but is not read here.
#
# The layer is flagged as built only after the initializer has succeeded.
# If +init_param+ raises, the layer stays unbuilt, so a later #call retries
# the build instead of running with a weight the initializer never filled.
# The regularizer assignment stays last so the return value (the assigned
# weight, or nil when there is no regularizer) is unchanged.
#
# @param input_shape [Array] unused.
def build(input_shape)
  @weight.data = Xumo::SFloat.new(@input_length)
  @weight_initializer.init_param(self, @weight)
  @built = true
  @weight_regularizer.param = @weight if @weight_regularizer
end
|
#call(input_tensor) ⇒ Object
25
26
27
28
|
# File 'lib/dnn/core/layers/embedding.rb', line 25
# Runs the layer on +input_tensor+, building it first if needed.
#
# @param input_tensor object whose +data+ is passed to #forward.
# @return [Tensor] the forward output, linked with no predecessor
#   and this layer as owner (Link.new(nil, self)).
def call(input_tensor)
  build(@input_shape) unless built?
  y = forward(input_tensor.data)
  link = Link.new(nil, self)
  Tensor.new(y, link)
end
|
#forward(x) ⇒ Object
37
38
39
40
41
42
43
44
|
# File 'lib/dnn/core/layers/embedding.rb', line 37
# Looks up a weight entry for every index in +x+.
#
# +x+ is cached in +@x+ for #backward. For each slice along the first axis,
# the slice's values index +@weight.data+. The result goes into the matching
# slice of a zero-filled SFloat shaped like +x+.
#
# @param x index array (NArray-style, given the +false+ slice syntax — TODO confirm).
# @return output array with the same shape as +x+.
def forward(x)
  @x = x
  out = Xumo::SFloat.zeros(*x.shape)
  (0...x.shape[0]).each do |row|
    out[row, false] = @weight.data[x[row, false]]
  end
  out
end
|
#get_params ⇒ Object
89
90
91
|
# File 'lib/dnn/core/layers/embedding.rb', line 89
# Parameters owned by this layer, keyed by name.
#
# @return [Hash{Symbol => Object}] a new hash holding only +:weight+.
def get_params
  {
    weight: @weight
  }
end
|
#load_hash(hash) ⇒ Object
83
84
85
86
87
|
# File 'lib/dnn/core/layers/embedding.rb', line 83
# Re-runs the constructor with settings read from a hash shaped like #to_hash output.
#
# Because this calls #initialize, the weight is reset to a fresh, unallocated Param.
#
# @param hash [Hash] reads :input_shape, :input_length, :weight_initializer and
#   :weight_regularizer.
# NOTE(review): assumes Regularizer.from_hash returns nil for a nil entry — confirm.
def load_hash(hash)
  initializer = Initializers::Initializer.from_hash(hash[:weight_initializer])
  regularizer = Regularizers::Regularizer.from_hash(hash[:weight_regularizer])
  initialize(hash[:input_shape], hash[:input_length],
             weight_initializer: initializer,
             weight_regularizer: regularizer)
end
|
#regularizers ⇒ Object
56
57
58
|
# File 'lib/dnn/core/layers/embedding.rb', line 56
# @return [Array] the weight regularizer in a one-element Array, or [] when none is set.
def regularizers
  return [] unless @weight_regularizer

  [@weight_regularizer]
end
|
#to_hash ⇒ Object
78
79
80
81
|
# File 'lib/dnn/core/layers/embedding.rb', line 78
# Serializes this layer's configuration through the superclass #to_hash.
#
# The initializer is always serialized. The regularizer is serialized only
# when present; otherwise nil is passed.
def to_hash
  initializer_hash = @weight_initializer.to_hash
  regularizer_hash = @weight_regularizer&.to_hash
  super(input_shape: @input_shape,
        input_length: @input_length,
        weight_initializer: initializer_hash,
        weight_regularizer: regularizer_hash)
end
|
#to_proc ⇒ Object
60
61
62
|
# File 'lib/dnn/core/layers/embedding.rb', line 60
# Returns this layer's #call bound as a Proc (a lambda taking the input tensor).
# The composition operators (<< and >>) build on it.
def to_proc
  bound_call = method(:call)
  bound_call.to_proc
end
|