Method: TensorStream::MathGradients._sum_grad

Defined in:
lib/tensor_stream/math_gradients.rb

._sum_grad(arg_x, arg_y, grad) ⇒ Object



205
206
207
208
209
210
211
212
213
214
# File 'lib/tensor_stream/math_gradients.rb', line 205

# Gradient of a sum reduction with respect to its inputs.
#
# The upstream gradient has the reduced shape. It is mapped back to the shape
# of +arg_x+ so that every summed element receives the upstream value.
#
# @param arg_x [Tensor] the tensor that was summed
# @param arg_y [Object] presumably the reduction axes of the sum (it is passed
#   to +reduced_shape+); confirm against the caller
# @param grad [Tensor] the gradient flowing in from the sum's output
# @return [Array] a two-element array: the gradient for +arg_x+, and +nil+ for +arg_y+
def self._sum_grad(arg_x, arg_y, grad)
  x_shape = _op(:shape, arg_x)
  kept_dims = ts.reduced_shape(x_shape, arg_y)
  multiples = _safe_shape_div(x_shape, kept_dims)
  reshaped = _op(:reshape, grad, kept_dims)

  # Ops are created in the same order as before (rank/zero?, tile, fill),
  # so graph node naming is unchanged.
  is_scalar = _op(:rank, grad).zero?
  tiled = _op(:tile, reshaped, multiples)
  filled = _op(:fill, x_shape, grad)

  # NOTE(review): assumes the :case op takes [predicates], then a default
  # branch, then one branch per predicate. If so, a scalar grad selects
  # +filled+ and any other grad selects +tiled+. Confirm against the :case
  # op definition.
  [_op(:case, [is_scalar], tiled, filled), nil]
end