Method: TensorStream::MathGradients._min_or_max_grad

Defined in:
lib/tensor_stream/math_gradients.rb

._min_or_max_grad(inputs, grad, selector_op) ⇒ Object



221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
# File 'lib/tensor_stream/math_gradients.rb', line 221

# Builds the gradients for a two-input element-wise op whose output picks,
# per element, either the first or the second input (by its name, the
# min/max family — the concrete choice is delegated to +selector_op+).
#
# The incoming gradient is routed to the first input wherever the selector
# mask holds and to the second input everywhere else; the remaining
# positions are filled with zeros. Each routed gradient is then summed over
# the axes returned by +_broadcast_gradient_args+ and reshaped to the shape
# of the input it belongs to.
#
# @param inputs [#[]] operands of the op; only indices 0 and 1 are read
# @param grad [#data_type] gradient flowing in from the op's output
# @param selector_op [#call] invoked as +call(x, y)+; its result is used as
#   the condition of +ts.where+
# @return [Array] two-element array +[gx, gy]+, the gradients with respect
#   to +inputs[0]+ and +inputs[1]+
def self._min_or_max_grad(inputs, grad, selector_op)
  lhs = inputs[0]
  rhs = inputs[1]

  lhs_shape = ts.shape(lhs)
  rhs_shape = ts.shape(rhs)

  # Zero tensor shaped and typed like the incoming gradient; it fills the
  # positions where a given input was not the selected one.
  zero_fill = ts.zeros(ts.shape(grad), dtype: grad.data_type)

  # Holds where the first input is the one selected by the op.
  lhs_selected = selector_op.call(lhs, rhs)

  # Presumably the axes along which each input was broadcast — confirm
  # against _broadcast_gradient_args.
  lhs_axes, rhs_axes = _broadcast_gradient_args(lhs_shape, rhs_shape)

  lhs_grad = ts.where(lhs_selected, grad, zero_fill, name: "x")
  rhs_grad = ts.where(lhs_selected, zero_fill, grad, name: "y")

  [
    ts.reshape(ts.reduce_sum(lhs_grad, lhs_axes), lhs_shape),
    ts.reshape(ts.reduce_sum(rhs_grad, rhs_axes), rhs_shape)
  ]
end