Class: DNN::Layers::Split

Inherits:
Layer
  • Object
show all
Includes:
LayerNode
Defined in:
lib/dnn/core/layers/split_layers.rb

Instance Attribute Summary

Attributes inherited from Layer

#input_shape, #output_shape

Instance Method Summary

Methods included from LayerNode

#forward

Methods inherited from Layer

#<<, #build, #built?, #call, call, #clean, #compute_output_shape, #forward, from_hash

Constructor Details

#initialize(axis: 1, dim: nil) ⇒ Split

Returns a new instance of Split.

Raises:
  • (DNNError) — if dim is nil.



10
11
12
13
14
15
# File 'lib/dnn/core/layers/split_layers.rb', line 10

# Creates a layer that splits its input into two outputs along +axis+.
#
# @param axis [Integer] axis along which the input is split (default: 1)
# @param dim [Integer] number of leading slices along +axis+ routed to the
#   first output; the remaining slices go to the second output
# @raise [DNNError] if +dim+ is nil (there is no sensible default split point)
def initialize(axis: 1, dim: nil)
  super()
  raise DNNError, "dim is nil" if dim.nil?

  @axis = axis
  @dim = dim
end

Instance Attribute Details

#axis ⇒ Object (readonly)

Returns the value of attribute axis.



7
8
9
# File 'lib/dnn/core/layers/split_layers.rb', line 7

# @return [Object] the axis the input is split along, as given to #initialize
def axis; @axis; end

#dim ⇒ Object (readonly)

Returns the value of attribute dim.



8
9
10
# File 'lib/dnn/core/layers/split_layers.rb', line 8

# @return [Object] the size of the first output along the split axis, as given to #initialize
def dim; @dim; end

Instance Method Details

#backward_node(dy1, dy2) ⇒ Object



25
26
27
# File 'lib/dnn/core/layers/split_layers.rb', line 25

# Backward pass: joins the gradients of the two outputs back together along
# the split axis, giving the gradient with respect to the original input.
#
# @param dy1 gradient of the first output (must respond to #concatenate)
# @param dy2 gradient of the second output
# @return the concatenation of +dy1+ and +dy2+ along +@axis+
def backward_node(dy1, dy2)
  split_axis = @axis
  dy1.concatenate(dy2, axis: split_axis)
end

#forward_node(x) ⇒ Object



17
18
19
20
21
22
23
# File 'lib/dnn/core/layers/split_layers.rb', line 17

# Forward pass: splits +x+ along +@axis+ into the first +@dim+ slices and
# everything after them.
#
# A single split index yields exactly two pieces, x[0...@dim] and x[@dim..].
# The previous implementation passed [@dim, size] as indices. That produced a
# third, always-empty piece, which destructuring silently discarded, and it
# guarded an unreachable re-concatenation branch.
#
# @param x input array; assumed to be Numo::NArray-like (responds to #split
#   with numpy-style indices) — TODO confirm for all backends
# @return [Array] two-element array [y1, y2]
def forward_node(x)
  y1, y2 = x.split([@dim], axis: @axis)
  [y1, y2]
end

#load_hash(hash) ⇒ Object



33
34
35
# File 'lib/dnn/core/layers/split_layers.rb', line 33

# Rebuilds this layer's configuration from a hash (as produced by #to_hash)
# by re-running #initialize with its :axis and :dim entries.
#
# @param hash [Hash] serialized layer; missing keys yield nil (or the hash's default)
def load_hash(hash)
  saved_axis, saved_dim = hash.values_at(:axis, :dim)
  initialize(axis: saved_axis, dim: saved_dim)
end

#to_hash ⇒ Object



29
30
31
# File 'lib/dnn/core/layers/split_layers.rb', line 29

# Serializes the layer: the parent's hash plus this layer's :axis and :dim
# (key order kept as axis, then dim).
#
# @return [Hash]
def to_hash
  super(
    axis: @axis,
    dim: @dim
  )
end