Method: TensorStream::Ops#gradients

Defined in:
lib/tensor_stream/ops.rb

#gradients(tensor_ys, wrt_xs, name: "gradients", stop_gradients: nil) ⇒ Object

Constructs the symbolic derivatives of tensor_ys with respect to each x in wrt_xs.

ys and xs are each a Tensor or a list of tensors. grad_ys is a list of Tensors holding the gradients received by the ys; the list must be the same length as ys. (Note: grad_ys does not appear in this method's signature.)

Arguments:
- tensor_ys: A Tensor or list of tensors to be differentiated.
- wrt_xs: A Tensor or list of tensors to be used for differentiation.
- stop_gradients: Optional. A Tensor or list of tensors not to differentiate through.



42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
# File 'lib/tensor_stream/ops.rb', line 42

# Builds the symbolic gradient of +tensor_ys+ with respect to every entry of
# +wrt_xs+. Each gradient program is stored on the graph of +tensor_ys+ under
# a name derived from the tensor names involved, so a later request for the
# same combination returns the stored node instead of rebuilding it.
#
# @param tensor_ys the tensor to differentiate; only its +op+ is used.
# @param wrt_xs a single tensor or a list of tensors to differentiate
#   against. A value that does not respond to +map+ is treated as one tensor.
# @param name currently not read by this method.
# @param stop_gradients optional tensor or list of tensors forwarded to
#   TensorStream::MathGradients.derivative; a single tensor is wrapped in a
#   list before use.
# @return [Array] one gradient node per entry of +wrt_xs+, in the same order
#   (a one-element list when a single tensor was given).
def gradients(tensor_ys, wrt_xs, name: "gradients", stop_gradients: nil)
  tensor_ys = tensor_ys.op
  tensor_graph = tensor_ys.graph

  # Accept a bare tensor wherever a list is accepted. Kernel#Array is not used
  # for this because it would call to_ary/to_a on the tensor object itself.
  xs = wrt_xs.respond_to?(:map) ? wrt_xs : [wrt_xs]
  stop_gradients = [stop_gradients] if stop_gradients && !stop_gradients.respond_to?(:map)

  # The stop-gradient suffix is identical for every x, so build it once.
  stops = stop_gradients ? stop_gradients.map(&:name).join("_") : ""

  x_ops = xs.map(&:op)
  x_ops.map do |x|
    gradient_program_name = "grad_#{tensor_ys.name}_#{x.name}_#{stops}".to_sym

    if tensor_graph.node_added?(gradient_program_name)
      # Reuse the gradient program registered by an earlier call.
      tensor_graph.get_node(gradient_program_name)
    else
      tensor_graph.name_scope("gradient_wrt_#{x.name}") do
        derivative_ops = TensorStream::MathGradients.derivative(tensor_ys, x,
                                                                graph: tensor_graph,
                                                                stop_gradients: stop_gradients)
        tensor_graph.add_node!(gradient_program_name, derivative_ops)
      end
    end
  end
end