Class: DNN::Layers::Embedding

Inherits:
TrainableLayer show all
Includes:
LayerNode
Defined in:
lib/dnn/core/layers/embedding.rb

Instance Attribute Summary collapse

Attributes inherited from TrainableLayer

#trainable

Attributes inherited from Layer

#input_shape, #output_shape

Instance Method Summary collapse

Methods included from LayerNode

#forward

Methods inherited from TrainableLayer

#clean

Methods inherited from Layer

#built?, #call, call, #clean, #compute_output_shape, #forward, from_hash

Constructor Details

#initialize(input_dim_or_shape, input_length, weight_initializer: Initializers::RandomUniform.new, weight_regularizer: nil) ⇒ Embedding

Returns a new instance of Embedding.

Parameters:

  • input_dim_or_shape (Integer | Array)

    Set input data dimension or shape.

  • input_length (Integer)

    Set the time series length of input data.

  • weight_initializer (DNN::Initializers::Initializer) (defaults to: Initializers::RandomUniform.new)

    Weight initializer.

  • weight_regularizer (DNN::Regularizers::Regularizer | NilClass) (defaults to: nil)

    Weight regularizer.



16
17
18
19
20
21
22
23
24
25
# File 'lib/dnn/core/layers/embedding.rb', line 16

# Creates an Embedding layer.
#
# @param input_dim_or_shape [Integer, Array] Input data dimension or shape.
#   An Array is stored as-is; any other value is wrapped in a one-element Array.
# @param input_length [Integer] Time series length of the input data.
# @param weight_initializer [DNN::Initializers::Initializer] Weight initializer.
# @param weight_regularizer [DNN::Regularizers::Regularizer, nil] Weight regularizer.
def initialize(input_dim_or_shape, input_length,
               weight_initializer: Initializers::RandomUniform.new,
               weight_regularizer: nil)
  super()
  @input_shape =
    if input_dim_or_shape.is_a?(Array)
      input_dim_or_shape
    else
      [input_dim_or_shape]
    end
  @input_length = input_length
  @weight_initializer = weight_initializer
  @weight_regularizer = weight_regularizer
  # Starts with a one-element array as data; #build assigns the real data.
  @weight = Param.new(nil, Xumo::SFloat[0])
end

Instance Attribute Details

#input_length ⇒ Object (readonly)

Returns the value of attribute input_length.



7
8
9
# File 'lib/dnn/core/layers/embedding.rb', line 7

# Reader for the input length passed to the constructor.
#
# @return [Integer] the value stored in @input_length
def input_length
  @input_length
end

#weight ⇒ Object (readonly)

Returns the value of attribute weight.



8
9
10
# File 'lib/dnn/core/layers/embedding.rb', line 8

# Reader for the layer's weight parameter.
#
# @return [Param] the parameter object stored in @weight
def weight
  @weight
end

#weight_initializer ⇒ Object (readonly)

Returns the value of attribute weight_initializer.



9
10
11
# File 'lib/dnn/core/layers/embedding.rb', line 9

# Reader for the weight initializer passed to the constructor.
#
# @return [DNN::Initializers::Initializer] the value stored in @weight_initializer
def weight_initializer
  @weight_initializer
end

#weight_regularizer ⇒ Object (readonly)

Returns the value of attribute weight_regularizer.



10
11
12
# File 'lib/dnn/core/layers/embedding.rb', line 10

# Reader for the weight regularizer passed to the constructor.
#
# @return [DNN::Regularizers::Regularizer, nil] the value stored in
#   @weight_regularizer; nil when no regularizer was given
def weight_regularizer
  @weight_regularizer
end

Instance Method Details

#backward_node(dy) ⇒ Object



43
44
45
46
47
48
49
50
51
# File 'lib/dnn/core/layers/embedding.rb', line 43

# Accumulates the upstream gradient into the weight gradient.
#
# For every position (i, j) over the first two axes of the stored input @x,
# dy[i, j] is added to the weight-gradient entry selected by @x[i, j].
#
# @param dy Upstream gradient, read element-wise as dy[i, j].
# @return [nil] always nil
def backward_node(dy)
  # Adds a zero array shaped like the weight data. Presumably this expands a
  # placeholder gradient to the full weight shape -- TODO confirm.
  @weight.grad = @weight.grad + Xumo::SFloat.zeros(*@weight.data.shape)
  rows, cols = @x.shape
  rows.times do |row|
    cols.times do |col|
      index = @x[row, col]
      @weight.grad[index] += dy[row, col]
    end
  end
  nil
end

#build(input_shape) ⇒ Object



27
28
29
30
31
32
# File 'lib/dnn/core/layers/embedding.rb', line 27

# Builds the layer: allocates the weight data, runs the weight initializer on
# it, and attaches the weight to the regularizer when one is set.
#
# @param input_shape Not used by this method; the shape stored by the
#   constructor (@input_shape) is passed to the superclass instead.
# @return the weight parameter when a regularizer is set, otherwise nil
def build(input_shape)
  super(@input_shape)
  @weight.data = Xumo::SFloat.new(@input_length)
  @weight_initializer.init_param(self, @weight)
  return unless @weight_regularizer

  @weight_regularizer.param = @weight
end

#forward_node(x) ⇒ Object



34
35
36
37
38
39
40
41
# File 'lib/dnn/core/layers/embedding.rb', line 34

# Looks up weight values for the given indices.
#
# The input is kept in @x for use by #backward_node. The result has the same
# shape as x; each slice along the first axis is filled with the weight
# entries selected by the matching slice of x.
#
# @param x Index array used to select entries of the weight data.
# @return [Xumo::SFloat] array shaped like x holding the selected weights
def forward_node(x)
  @x = x
  shape = x.shape
  y = Xumo::SFloat.zeros(*shape)
  (0...shape[0]).each do |row|
    # `false` stands in for the remaining axes (NArray indexing) -- TODO confirm.
    y[row, false] = @weight.data[x[row, false]]
  end
  y
end

#get_params ⇒ Object



68
69
70
# File 'lib/dnn/core/layers/embedding.rb', line 68

# Lists this layer's parameters by name.
#
# @return [Hash] a new Hash mapping :weight to the weight parameter
def get_params
  params = {}
  params[:weight] = @weight
  params
end

#load_hash(hash) ⇒ Object



62
63
64
65
66
# File 'lib/dnn/core/layers/embedding.rb', line 62

# Re-initializes this layer from a hash.
#
# Reads the :input_shape, :input_length, :weight_initializer and
# :weight_regularizer keys. The last two are rebuilt through the respective
# from_hash class methods before being handed to #initialize.
#
# @param hash [Hash] Hash with the keys listed above.
# @return whatever #initialize returns
def load_hash(hash)
  initializer = Initializers::Initializer.from_hash(hash[:weight_initializer])
  regularizer = Regularizers::Regularizer.from_hash(hash[:weight_regularizer])
  initialize(hash[:input_shape], hash[:input_length],
             weight_initializer: initializer,
             weight_regularizer: regularizer)
end

#regularizers ⇒ Object



53
54
55
# File 'lib/dnn/core/layers/embedding.rb', line 53

# Lists the regularizers attached to this layer.
#
# @return [Array] a one-element Array holding the weight regularizer, or an
#   empty Array when none is set
def regularizers
  return [] unless @weight_regularizer

  [@weight_regularizer]
end

#to_hash ⇒ Object



57
58
59
60
# File 'lib/dnn/core/layers/embedding.rb', line 57

# Serializes the layer's configuration.
#
# Passes the input shape, input length, and the hash forms of the weight
# initializer and (when present) the weight regularizer to the superclass
# implementation.
#
# @return the result of the superclass #to_hash
def to_hash
  fields = {
    input_shape: @input_shape,
    input_length: @input_length,
    weight_initializer: @weight_initializer.to_hash,
    weight_regularizer: @weight_regularizer&.to_hash
  }
  super(**fields)
end