Method: TensorStream::MathGradients._concat_grad_helper

Defined in:
lib/tensor_stream/math_gradients.rb

._concat_grad_helper(op, grad, start_value_index, end_value_index, dim_index) ⇒ Object



267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
# File 'lib/tensor_stream/math_gradients.rb', line 267

# Splits the gradient flowing into a concat op back into one gradient per
# op input (appears to mirror TensorFlow's _ConcatGradHelper).
#
# The op's inputs are the concatenated values plus one axis input. The values
# occupy the half-open index range [start_value_index, end_value_index) of
# op.inputs, and the axis input sits at dim_index, either before or after
# that range.
#
# @param op [Object] the concat op node; only #inputs is read here
# @param grad [Object] gradient with respect to the concat output
# @param start_value_index [Integer] index of the first value input
# @param end_value_index [Integer] index one past the last value input
#   (exclusive, like a Python slice end)
# @param dim_index [Integer] index of the concat-axis input
# @return [Array] one gradient per op input, in input order, with nil in the
#   axis input's slot
def self._concat_grad_helper(op, grad, start_value_index, end_value_index, dim_index)
  # The axis input precedes the values iff the values end after it.
  axis_first = end_value_index > dim_index

  # Degenerate concatenation (a single value): the whole grad belongs to it.
  return axis_first ? [nil, grad] : [grad, nil] if op.inputs.size == 2

  concat_dim = op.inputs[dim_index]
  # Exclusive range: an inclusive one would also capture the axis input when
  # it directly follows the values (end_value_index == dim_index).
  input_values = op.inputs[start_value_index...end_value_index]
  # Normalize a negative axis (e.g. -1 => rank - 1).
  non_neg_concat_dim = concat_dim % ts.rank(input_values[0])
  sizes = _extract_input_shapes(input_values)

  # NOTE(review): assumes _extract_input_shapes yields one 1-D shape per
  # input, so stacking on axis 1 gives a [rank, num_values] matrix whose row
  # non_neg_concat_dim holds each input's extent along the concat axis.
  slicer = ts.slice(ts.stack(sizes, axis: 1), [non_neg_concat_dim, 0], [1, -1])
  sizes = ts.squeeze(slicer)

  out_grads = ts.split(grad, sizes, axis: non_neg_concat_dim, num: input_values.size)
  axis_first ? [nil] + out_grads : out_grads + [nil]
end