221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
|
# File 'lib/tensor_stream/math_gradients.rb', line 221
# Builds the two input gradients for a binary op whose backward pass sends
# each element of the incoming gradient to exactly one of its two inputs
# (min/max, going by the method name).
#
# @param inputs an indexable collection; only inputs[0] and inputs[1] are read
# @param grad the incoming gradient; must respond to +data_type+
# @param selector_op a callable invoked as +selector_op.call(x, y)+; its
#   result is handed to +ts.where+ as the condition
# @return [Array] two elements: the gradient for inputs[0], then for inputs[1]
def self._min_or_max_grad(inputs, grad, selector_op)
  lhs = inputs[0]
  rhs = inputs[1]
  grad_dtype = grad.data_type

  lhs_shape = ts.shape(lhs)
  rhs_shape = ts.shape(rhs)
  # Zero filler with the same shape and dtype as the incoming gradient.
  zero_fill = ts.zeros(ts.shape(grad), dtype: grad_dtype)

  lhs_selected = selector_op.call(lhs, rhs)
  # Presumably the axes along which each input was broadcast; the helper is
  # defined elsewhere -- confirm against its implementation.
  lhs_axes, rhs_axes = _broadcast_gradient_args(lhs_shape, rhs_shape)

  # The same condition drives both calls with the branches swapped, so the
  # two partial gradients are complementary.
  lhs_partial = ts.where(lhs_selected, grad, zero_fill, name: "x")
  rhs_partial = ts.where(lhs_selected, zero_fill, grad, name: "y")

  # Sum over the broadcast axes, then reshape to the matching input's shape.
  [
    [lhs_partial, lhs_axes, lhs_shape],
    [rhs_partial, rhs_axes, rhs_shape],
  ].map do |partial, axes, shape|
    ts.reshape(ts.reduce_sum(partial, axes), shape)
  end
end
|