Method: TensorStream::Ops#gradients
- Defined in: lib/tensor_stream/ops.rb
#gradients(tensor_ys, wrt_xs, name: "gradients", stop_gradients: nil) ⇒ Object
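Constructs symbolic derivatives of tensor_ys with respect to each x in wrt_xs.
tensor_ys and wrt_xs are each described as a Tensor or a list of tensors. In the implementation listed here, tensor_ys is used as a single tensor (tensor_ys.op) and wrt_xs is iterated with map, so wrt_xs should be passed as an array. The method returns one gradient node per entry in wrt_xs, in the same order. There is no grad_ys parameter in this signature.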
Arguments:
- tensor_ys : A Tensor or list of tensors to be differentiated.
- wrt_xs : A Tensor or list of tensors to be used for differentiation.
- name : Optional, defaults to "gradients". Not referenced in the method body listed here.
- stop_gradients : Optional. A Tensor or list of tensors not to differentiate through.
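As an illustration of what "not to differentiate through" can mean, the sketch below differentiates a tiny expression tree and treats any node listed in stop as a constant. It assumes semantics similar to TensorFlow's stop_gradients; TensorStream's actual behavior may differ. Var, Const, Add, Mul, evaluate and derivative are hypothetical names for this example only, not TensorStream classes, and the derivative is evaluated numerically rather than built as graph nodes.

```ruby
# Hypothetical mini expression tree (NOT TensorStream classes).
Var   = Struct.new(:name)
Const = Struct.new(:value)
Add   = Struct.new(:left, :right)
Mul   = Struct.new(:left, :right)

# Evaluates a node given a hash of variable values.
def evaluate(node, env)
  case node
  when Var   then env.fetch(node.name)
  when Const then node.value
  when Add   then evaluate(node.left, env) + evaluate(node.right, env)
  when Mul   then evaluate(node.left, env) * evaluate(node.right, env)
  end
end

# Derivative of node with respect to x, evaluated at env.
# Nodes listed in stop are treated as constants (assumed semantics).
def derivative(node, x, env, stop: [])
  return 1.0 if node.equal?(x)                      # d(x)/d(x)
  return 0.0 if stop.any? { |s| s.equal?(node) }    # do not look inside stopped nodes
  case node
  when Var, Const
    0.0
  when Add
    derivative(node.left, x, env, stop: stop) + derivative(node.right, x, env, stop: stop)
  when Mul
    # Product rule.
    derivative(node.left, x, env, stop: stop) * evaluate(node.right, env) +
      evaluate(node.left, env) * derivative(node.right, x, env, stop: stop)
  end
end

a = Var.new("a")
b = Mul.new(Const.new(2.0), a)   # b = 2a
y = Add.new(a, b)                # y = a + b
env = { "a" => 3.0 }

full    = derivative(y, a, env)             # differentiates through b
stopped = derivative(y, a, env, stop: [b])  # b is held constant
```

Here full accounts for both paths from a to y, while stopped only sees the direct dependence on a, because b is not differentiated through.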
# File 'lib/tensor_stream/ops.rb', line 42

def gradients(tensor_ys, wrt_xs, name: "gradients", stop_gradients: nil)
  tensor_ys = tensor_ys.op
  gs = wrt_xs.map(&:op).collect { |x|
    stops = stop_gradients ? stop_gradients.map(&:name).join("_") : ""
    gradient_program_name = "grad_#{tensor_ys.name}_#{x.name}_#{stops}".to_sym
    tensor_graph = tensor_ys.graph

    tensor_program = if tensor_graph.node_added?(gradient_program_name)
      tensor_graph.get_node(gradient_program_name)
    else
      tensor_graph.name_scope("gradient_wrt_#{x.name}") do
        derivative_ops = TensorStream::MathGradients.derivative(tensor_ys, x, graph: tensor_graph, stop_gradients: stop_gradients)
        tensor_graph.add_node!(gradient_program_name, derivative_ops)
      end
    end
    tensor_program
  }

  gs
end
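The listing caches each gradient under a name built from the ys name, the x name, and the joined stop_gradients names, and reuses the node if the graph already holds it. The following is a minimal sketch of that lookup-or-build pattern in isolation. TinyGraph, gradient_key and cached_gradient are hypothetical stand-ins written for this example; they only mimic the three graph calls seen in the listing (node_added?, get_node, add_node!), and the sketch assumes add_node! returns the stored node.

```ruby
# Hypothetical stand-in for the graph (NOT the TensorStream graph class).
class TinyGraph
  def initialize
    @nodes = {}
  end

  def node_added?(name)
    @nodes.key?(name)
  end

  def get_node(name)
    @nodes[name]
  end

  # Assumption: returns the node that was stored.
  def add_node!(name, node)
    @nodes[name] = node
  end
end

# Builds a key in the same format as the listing.
def gradient_key(ys_name, x_name, stop_names = nil)
  stops = stop_names ? stop_names.join("_") : ""
  "grad_#{ys_name}_#{x_name}_#{stops}".to_sym
end

# Returns the cached node for the key, or builds it with the block and stores it.
def cached_gradient(graph, ys_name, x_name, stop_names = nil)
  key = gradient_key(ys_name, x_name, stop_names)
  return graph.get_node(key) if graph.node_added?(key)
  graph.add_node!(key, yield)
end

graph  = TinyGraph.new
builds = 0
first  = cached_gradient(graph, "y", "x") { builds += 1; Object.new }
second = cached_gradient(graph, "y", "x") { builds += 1; Object.new }         # reused
third  = cached_gradient(graph, "y", "x", ["b"]) { builds += 1; Object.new }  # different key
```

Because the stop_gradients names are part of the key, the same (ys, x) pair with a different stop list gets its own node. With no stop list the key ends in a trailing underscore, for example :grad_y_x_, which follows from the string format in the listing.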