Module: TensorStream::Train::LearningRateDecay

Includes:
OpHelper, Ops, Utils
Included in:
TensorStream::Trainer
Defined in:
lib/tensor_stream/train/learning_rate_decay.rb

Constant Summary

Constants included from Ops

Ops::FLOATING_POINT_TYPES, Ops::INTEGER_TYPES, Ops::NUMERIC_TYPES

Instance Method Summary

Methods included from Ops

#abs, #acos, #add_n, #asin, #assert_equal, #atan, #broadcast_gradient_args, #case, #cast, #cast_axis, #check_numerics, #clip_by_norm, #concat, #cond, #constant_initializer, #cumprod, #dynamic_partition, #exp, #expand_dims, #eye, #floor_div, #gather, #glorot_uniform_initializer, #gradients, #identity, #index, #invert_permutation, #log, #log1p, #logical_and, #maximum, #minimum, #multiply, #negative, #not_equal, #ones, #ones_initializer, #ones_like, #pack, #pad, #print, #random_normal, #random_uniform_initializer, #reciprocal, #reduce, #reduce_mean, #reshape, #sec, #setdiff1d, #shape_n, #slice, #split, #sqrt, #square, #squared_difference, #squeeze, #stack, #stop_gradient, #transpose, #truncated_normal, #unpack, #unstack, #where, #zeros_initializer, #zeros_like

Methods included from OpStub

#add, #argmax, #argmin, #ceil, #cos, #div, #equal, #expand_dims, #fill, #floor, #floor_div, #greater, #greater_equal, #less, #less_equal, #log, #mat_mul, #max, #min, #mod, #mul, #negate, #not_equal, #ones_like, #pow, #prod, #random_uniform, #range, #rank, #reshape, #round, #rsqrt, #shape, #sigmoid, #sign, #sin, #size, #strided_slice, #sub, #sum, #tan, #tanh, #tile, #top_k, #zeros

Methods included from OpHelper

#_op, #cons, #format_source, #fp_type?, #i_cons, #i_op, #i_var, #int_type?, #reduced_shape, #shape_eval, #shape_full_specified, #shapes_fully_specified_and_equal

Methods included from Utils

#__v_scope_name, #apply_data_type_coercion, #assign, #check_allowed_types, #check_data_types, #check_if_dense, #colocate_with, #constant, #control_dependencies, #convert_to_tensor, #device, #disable_eager_execution, #dynamic_stitch, #enable_eager_execution, #executing_eagerly?, #float32, #get_collection, #get_default_graph, #get_variable, #get_variable_scope, #global_variables_initializer, #graph, #group, #image, #layers, #list_local_devices, #math, #name_scope, #placeholder, #program, #reset_default_graph, #session, #set_random_seed, #train, #trainable_variables, #variable, #variable_scope

Instance Method Details

#exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase: false, name: nil) ⇒ Object

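Applies exponential decay to the learning rate. The result is `learning_rate * decay_rate ** (global_step / decay_steps)`; when `staircase` is true the exponent is floored, so the rate drops in discrete steps instead of continuously.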



# File 'lib/tensor_stream/train/learning_rate_decay.rb', line 12

# Applies exponential decay to the learning rate.
#
# Builds the graph expression
#
#   learning_rate * decay_rate ** (global_step / decay_steps)
#
# inside a name scope. When +staircase+ is true the exponent is floored, so the
# rate changes in discrete steps rather than continuously.
#
# @param learning_rate [Object] initial rate; passed through +convert_to_tensor+,
#   and its data type is the one every other operand is cast to
# @param global_step [Object] current step; must not be nil
# @param decay_steps [Object] divisor applied to +global_step+ in the exponent
# @param decay_rate [Object] base of the exponential decay
# @param staircase [Boolean] floor the exponent when true
# @param name [String, nil] name for the scope (default scope name is
#   "ExponentialDecay") and for the final multiply op
# @return [Object] whatever +multiply+ returns for the decayed rate
# @raise [TensorStream::ValueError] if +global_step+ is nil
def exponential_decay(learning_rate, global_step, decay_steps, decay_rate, staircase: false, name: nil)
  raise TensorStream::ValueError, "global_step is required for exponential_decay." if global_step.nil?

  scope_inputs = [learning_rate, global_step, decay_steps, decay_rate]
  name_scope(name, default: "ExponentialDecay", values: scope_inputs) do
    initial_rate = convert_to_tensor(learning_rate, name: "learning_rate")
    dtype = initial_rate.data_type

    # Keep this cast order: op creation order may influence auto-generated
    # node names (NOTE(review): unverified, but cheap to preserve).
    period = cast(decay_steps, dtype)
    factor = cast(decay_rate, dtype)
    step = cast(global_step, dtype)

    # Named `elapsed` rather than `p`, which would shadow Kernel#p.
    elapsed = step / period
    exponent = staircase ? floor(elapsed) : elapsed

    decay = pow(factor, exponent)
    multiply(initial_rate, decay, name: name)
  end
end