267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
|
# File 'lib/tensor_stream/math_gradients.rb', line 267
# Gradient helper for concat ops: splits the incoming gradient `grad` into one
# piece per concatenated input, sized by each input's extent along the concat
# dimension.
#
# @param op [Tensor] the concat op whose gradient is being computed. Its inputs
#   are the concatenated values plus one concat-dimension input.
#   NOTE(review): inferred from the `op.inputs.size - 1` split count; confirm.
# @param grad [Tensor] the gradient flowing into the concat op's output
# @param start_value_index [Integer] index in op.inputs of the first concatenated value
# @param end_value_index [Integer] index in op.inputs of the last concatenated value
# @param dim_index [Integer] index in op.inputs of the concat-dimension input
# @return [Array] one entry per op input: the gradient slices for the values,
#   in order, and nil in the position of the concat-dimension input
def self._concat_grad_helper(op, grad, start_value_index, end_value_index, dim_index)
  # Only one value was concatenated, so the whole gradient belongs to it.
  if op.inputs.size == 2
    return end_value_index <= dim_index ? [grad, nil] : [nil, grad]
  end

  concat_dim = op.inputs[dim_index]
  input_values = op.inputs[start_value_index..end_value_index]
  # Presumably maps a negative axis (e.g. -1) into [0, rank).
  non_neg_concat_dim = concat_dim % ts.rank(input_values[0])

  # The split sizes must come from the inputs' *shapes*, not their values.
  # Assuming ts.shape yields a 1-D vector of length rank, stacking these along
  # axis 1 gives a [rank, num_values] matrix. Row `non_neg_concat_dim` of that
  # matrix holds each input's size along the concat dimension.
  shapes = input_values.map { |value| ts.shape(value) }
  slicer = ts.slice(ts.stack(shapes, axis: 1), [non_neg_concat_dim, 0], [1, -1])
  sizes = ts.squeeze(slicer)

  out_grads = ts.split(grad, sizes, axis: non_neg_concat_dim, num: op.inputs.size - 1)
  # The concat-dimension input receives no gradient (nil). It goes on whichever
  # side of the values it occupies in op.inputs.
  end_value_index <= dim_index ? out_grads + [nil] : [nil] + out_grads
end
|