Method: TensorStream::MathGradients._sum_grad
- Defined in: `lib/tensor_stream/math_gradients.rb`
._sum_grad(arg_x, arg_y, grad) ⇒ Array — a two-element array: the gradient with respect to `arg_x`, followed by `nil` for `arg_y`.
Source: lines 205–214
# File 'lib/tensor_stream/math_gradients.rb', line 205
#
# Gradient helper for a sum reduction.
#
# Expands the upstream gradient +grad+ back out to the shape of +arg_x+:
# - if +grad+ has rank 0, it is filled across the full shape of +arg_x+;
# - otherwise it is reshaped to the reduced shape with kept dimensions
#   (obtained from +ts.reduced_shape+ over the axes +arg_y+) and tiled by
#   the shape ratio computed with +_safe_shape_div+.
#
# NOTE(review): the branch mapping above assumes the :case op is called as
# (predicates, default, *branches), i.e. +filled+ is chosen when the
# rank-0 predicate holds and +tiled+ is the default. Confirm against :case.
#
# @param arg_x the input that was summed; its shape is the target shape
# @param arg_y the reduction axes, forwarded to +ts.reduced_shape+
# @param grad the gradient arriving at the sum's output
# @return [Array] the gradient for +arg_x+, and +nil+ for +arg_y+
def self._sum_grad(arg_x, arg_y, grad)
  full_shape = _op(:shape, arg_x)
  kept_dims_shape = ts.reduced_shape(full_shape, arg_y)
  multiples = _safe_shape_div(full_shape, kept_dims_shape)
  reshaped = _op(:reshape, grad, kept_dims_shape)

  # Operands are created in the same order as a single inline :case call
  # would evaluate them, so graph node creation order is preserved.
  scalar_grad = _op(:rank, grad).zero?
  tiled = _op(:tile, reshaped, multiples)
  filled = _op(:fill, full_shape, grad)

  input_grad = _op(:case, [scalar_grad], tiled, filled)
  [input_grad, nil]
end