11
12
13
14
15
16
17
18
19
20
21
22
23
|
# File 'lib/tensor_stream/math_gradients.rb', line 11
# Builds the symbolic gradient of +tensor+ with respect to +wrt_dx+.
#
# @param tensor the node being differentiated
# @param wrt_dx the node the gradient is taken with respect to
# @param options [Hash] +:stop_gradients+ is forwarded to +_propagate+;
#   an absent or nil value is replaced by an empty Array
# @return +ones_like(tensor)+ when +tensor+ and +wrt_dx+ are the same object;
#   +zeros_like(wrt_dx)+ when +tensor.name+ is not among +wrt_dx.consumers+
#   or when +_propagate+ returns nil/false; otherwise the value +_propagate+
#   returns.
def self.derivative(tensor, wrt_dx, options = {})
  # Differentiating a node with respect to itself.
  return i_op(:ones_like, tensor) if tensor.equal?(wrt_dx)

  # NOTE(review): assumes +consumers+ lists the names of (possibly transitive)
  # downstream nodes — if +tensor+ is not one of them there is no path.
  return i_op(:zeros_like, wrt_dx) unless wrt_dx.consumers.include?(tensor.name)

  # Keep only consumers of wrt_dx that feed into tensor or are tensor itself;
  # presumably graph.nodes maps a node name to its node object — confirm.
  on_path = wrt_dx.consumers.select do |consumer_name|
    consumer = tensor.graph.nodes[consumer_name]
    consumer.consumers.include?(tensor.name) || consumer.equal?(tensor)
  end
  nodes_to_compute = on_path.compact + [wrt_dx.name]

  # Seed gradient: a tensor shaped like +tensor+ filled with 1.
  # NOTE(review): the constant's dtype comes from wrt_dx, not tensor — confirm intent.
  seed_shape = ts.shape(tensor)
  seed_one = ts.constant(1, dtype: wrt_dx.data_type)
  seed = i_op(:fill, seed_shape, seed_one)

  stop_gradients = options[:stop_gradients] || []
  propagated = _propagate(seed, tensor, wrt_dx, nodes_to_compute, stop_gradients)

  # Fall back to a zero gradient when propagation produced nothing.
  propagated || i_op(:zeros_like, wrt_dx)
end
|