Class: DNN::Layers::MergeLayer

Inherits:
Layer (itself a subclass of Object)
Includes:
MergeLayerNode
Defined in:
lib/dnn/core/layers/merge_layers.rb

Direct Known Subclasses

Add, Concatenate, Div, Dot, Mul, Sub

Instance Attribute Summary

Attributes inherited from Layer

#input_shape

Class Method Summary

.call(x1, x2, *args) ⇒ Object

Instance Method Summary

#call(input_tensor1, input_tensor2) ⇒ Object

Methods included from MergeLayerNode

#backward, #backward_node, #forward, #forward_node

Methods inherited from Layer

#build, #built?, #clean, #forward, from_hash, #initialize, #load_hash, #output_shape, #to_hash
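The MergeLayerNode methods listed above suggest a split in which the mixin supplies the public #forward/#backward entry points and each subclass (Add, Mul, and so on) supplies the math in #forward_node/#backward_node. The following is a minimal plain-Ruby sketch of that idea, assuming this division of labor. ToyMergeNode and ToyMul are hypothetical names, not ruby-dnn classes; plain Arrays stand in for tensors, and any graph bookkeeping a real node module would do is omitted.

```ruby
# Hypothetical sketch: ToyMergeNode and ToyMul are illustrative names, not
# ruby-dnn classes. The mixin owns the public forward/backward entry points
# and delegates the math to hooks that each including class defines.
module ToyMergeNode
  def forward(x1, x2)
    forward_node(x1, x2)
  end

  # Returns [dx1, dx2], one gradient per input.
  def backward(dy)
    backward_node(dy)
  end
end

class ToyMul
  include ToyMergeNode

  # Element-wise product; inputs are cached for the backward pass.
  def forward_node(x1, x2)
    @x1, @x2 = x1, x2
    x1.zip(x2).map { |a, b| a * b }
  end

  # d(x1 * x2)/dx1 = x2 and d(x1 * x2)/dx2 = x1, scaled by the upstream gradient.
  def backward_node(dy)
    dx1 = dy.zip(@x2).map { |g, b| g * b }
    dx2 = dy.zip(@x1).map { |g, a| g * a }
    [dx1, dx2]
  end
end

layer = ToyMul.new
p layer.forward([2, 3], [4, 5])  # => [8, 15]
p layer.backward([1, 1])         # => [[4, 5], [2, 3]]
```

In Ruby's method lookup, an included module sits between the class and its superclass, so a #forward defined in an included module takes precedence over one inherited from the superclass. That is why #forward can appear both under MergeLayerNode and under Layer.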

Constructor Details

This class inherits a constructor from DNN::Layers::Layer

Class Method Details

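.call(x1, x2, *args) ⇒ Object

Creates a new layer instance, forwarding *args to the constructor, and immediately applies it to x1 and x2 (equivalent to new(*args).call(x1, x2)).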

# File 'lib/dnn/core/layers/merge_layers.rb', line 31

def self.call(x1, x2, *args)
  new(*args).call(x1, x2)
end
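The class-level .call is a construct-then-apply shortcut. The following is a minimal sketch of the same pattern in plain Ruby, assuming a hypothetical ToyScaledAdd class (not part of ruby-dnn) that works on Arrays instead of tensors and takes one positional constructor argument.

```ruby
# Hypothetical stand-in for a merge layer; not a ruby-dnn class.
class ToyScaledAdd
  # Constructor options arrive via the class-level .call's *args.
  def initialize(scale = 1)
    @scale = scale
  end

  # Class-level shortcut mirroring MergeLayer.call: construct, then apply.
  def self.call(x1, x2, *args)
    new(*args).call(x1, x2)
  end

  # Element-wise (x1 + x2) * scale over plain Ruby arrays.
  def call(x1, x2)
    x1.zip(x2).map { |a, b| (a + b) * @scale }
  end
end

p ToyScaledAdd.call([1, 2], [3, 4])      # => [4, 6]
p ToyScaledAdd.call([1, 2], [3, 4], 10)  # => [40, 60]
p ToyScaledAdd.([1, 2], [3, 4])          # => [4, 6]
```

Ruby treats `obj.(...)` as `obj.call(...)`, so the shortcut also reads as `ToyScaledAdd.(x1, x2)`. Under Ruby 3's keyword-argument separation, keywords captured by a bare *args are passed on as a positional Hash by new(*args), so a wrapper meant to forward keyword options would typically also accept and pass **kwargs.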

Instance Method Details

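#call(input_tensor1, input_tensor2) ⇒ Object

Wraps each input in a Tensor unless it is already a Tensor or Param, builds the layer on the first call (from input_tensor1's shape without the leading batch dimension, or from [1] when its data is not a Numo::NArray), and then runs forward on both inputs.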

# File 'lib/dnn/core/layers/merge_layers.rb', line 35

def call(input_tensor1, input_tensor2)
  # Wrap raw data in a Tensor unless the input is already a Tensor or Param.
  input_tensor1 = Tensor.new(input_tensor1) if !input_tensor1.is_a?(Tensor) && !input_tensor1.is_a?(Param)
  input_tensor2 = Tensor.new(input_tensor2) if !input_tensor2.is_a?(Tensor) && !input_tensor2.is_a?(Param)
  # Build lazily from the first input: drop the batch dimension for arrays, use [1] otherwise.
  if input_tensor1.data.is_a?(Numo::NArray)
    build(input_tensor1.data.shape[1..-1]) unless built?
  else
    build([1]) unless built?
  end
  forward(input_tensor1, input_tensor2)
end
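The lazy-build logic in #call can be sketched in plain Ruby as follows. ToyTensor and ToyAdd are hypothetical stand-ins, not ruby-dnn classes, and nested Arrays play the role of Numo::NArray. The sketch keeps the same three steps: wrap raw inputs, build once from the first input's per-sample shape (or [1] for non-array data), then forward.

```ruby
# Hypothetical sketch: ToyTensor and ToyAdd are illustrative stand-ins, not
# ruby-dnn classes. Plain nested Arrays play the role of Numo::NArray.
ToyTensor = Struct.new(:data)

class ToyAdd
  attr_reader :input_shape

  def initialize
    @input_shape = nil
  end

  def built?
    !@input_shape.nil?
  end

  # Records the per-sample input shape (batch dimension already removed).
  def build(input_shape)
    @input_shape = input_shape
  end

  def call(t1, t2)
    # Wrap raw data so forward always receives tensors.
    t1 = ToyTensor.new(t1) unless t1.is_a?(ToyTensor)
    t2 = ToyTensor.new(t2) unless t2.is_a?(ToyTensor)
    # Build once, from the first input only.
    if t1.data.is_a?(Array)
      build(shape_of(t1.data)[1..-1]) unless built?
    else
      build([1]) unless built?
    end
    forward(t1, t2)
  end

  private

  # Shape of a rectangular nested Array, e.g. [[1, 2, 3], [4, 5, 6]] -> [2, 3].
  def shape_of(a)
    a.is_a?(Array) ? [a.size] + shape_of(a.first) : []
  end

  # Element-wise sum of two 2-D batches, or of two scalars.
  def forward(t1, t2)
    a, b = t1.data, t2.data
    sum =
      if a.is_a?(Array)
        a.zip(b).map { |r1, r2| r1.zip(r2).map { |x, y| x + y } }
      else
        a + b
      end
    ToyTensor.new(sum)
  end
end

layer = ToyAdd.new
out = layer.call([[1, 2, 3], [4, 5, 6]], [[10, 20, 30], [40, 50, 60]])
p layer.input_shape  # => [3]
p out.data           # => [[11, 22, 33], [44, 55, 66]]

scalar = ToyAdd.new
p scalar.call(2, 3).data  # => 5
p scalar.input_shape      # => [1]
```

In this sketch, as in the listing above, only the first input determines the recorded shape, and once built? is true a later call with a differently shaped batch does not rebuild the layer.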