Method: TensorStream::MathGradients.derivative

Defined in:
lib/tensor_stream/math_gradients.rb

.derivative(tensor, wrt_dx, options = {}) ⇒ Object



(Source lines 11–23 of lib/tensor_stream/math_gradients.rb)
# File 'lib/tensor_stream/math_gradients.rb', line 11

# Builds the symbolic gradient of +tensor+ with respect to +wrt_dx+.
#
# Two cases are answered without walking the graph:
# * +tensor+ is the very same object as +wrt_dx+: result is +ones_like(tensor)+.
# * +tensor+'s name is not among +wrt_dx.consumers+: result is +zeros_like(wrt_dx)+.
#
# Otherwise back-propagation is seeded with a tensor of ones. The seed is shaped
# like +tensor+ and uses +wrt_dx+'s data type. Propagation is restricted to the
# consumers of +wrt_dx+ that either are +tensor+ or list +tensor+ among their own
# consumers, plus +wrt_dx+ itself.
#
# NOTE(review): +consumers+ appears to hold every downstream node name
# (transitive), not just direct children — confirm against the graph builder.
#
# @param tensor [Object] the output whose gradient is requested
# @param wrt_dx [Object] the input the gradient is taken with respect to
# @param options [Hash] may contain +:stop_gradients+, an array forwarded as-is
#   to +_propagate+ (defaults to an empty array)
# @return [Object] the op produced by +_propagate+, or +zeros_like(wrt_dx)+ when
#   +_propagate+ yields nil/false
def self.derivative(tensor, wrt_dx, options = {})
  # d(x)/d(x) is identically one.
  return i_op(:ones_like, tensor) if tensor.equal?(wrt_dx)

  downstream = wrt_dx.consumers
  # tensor is not downstream of wrt_dx, so it cannot depend on it.
  return i_op(:zeros_like, wrt_dx) unless downstream.include?(tensor.name)

  graph_nodes = tensor.graph.nodes
  # Keep only consumers that lie on a path from wrt_dx to tensor.
  on_path = downstream.select do |consumer_name|
    consumer = graph_nodes[consumer_name]
    consumer.consumers.include?(tensor.name) || consumer.equal?(tensor)
  end
  nodes_to_compute = on_path.compact << wrt_dx.name

  seed = i_op(:fill, ts.shape(tensor), ts.constant(1, dtype: wrt_dx.data_type))
  stop_list = options[:stop_gradients] || []

  _propagate(seed, tensor, wrt_dx, nodes_to_compute, stop_list) || i_op(:zeros_like, wrt_dx)
end