49
50
51
52
53
54
55
56
57
58
59
|
# File 'lib/tensorflow/train/optimizer.rb', line 49
# Computes gradients of +loss+ with respect to the variables to be trained.
#
# @param loss the value to differentiate (passed straight through to
#   Graph::Gradients#gradients)
# @param var_list [Array, nil] variables to differentiate against. When nil,
#   the graph's TRAINABLE_VARIABLES collection is used instead; an explicitly
#   passed empty list is kept as-is (and therefore raises below).
# @param grad_loss optional starting gradient for +loss+, forwarded as +grad_ys+
# @return [Array<Array>] each gradient returned by Graph::Gradients#gradients
#   paired positionally with its variable, i.e. [[grad, var], ...]
# @raise [Error::InvalidArgumentError] if no variables are available to train
def compute_gradients(loss, var_list: nil, grad_loss: nil)
  variables = var_list
  variables ||= graph.get_collection_ref(Tensorflow::Graph::GraphKeys::TRAINABLE_VARIABLES)

  # The collection lookup may yield nil, so guard both nil and empty.
  raise Error::InvalidArgumentError, 'There are no variables to train for the loss function' if variables.nil? || variables.empty?

  gradient_values = Graph::Gradients.new(graph).gradients(loss, variables, grad_ys: grad_loss)
  gradient_values.zip(variables)
end
|