Class: TensorStream::MathGradients

Inherits:
Object
Extended by:
OpHelper
Defined in:
lib/tensor_stream/math_gradients.rb

Overview

Class that provides symbolic automatic differentiation of TensorStream tensors and operations.
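A minimal usage sketch (the variable, session, and operator-overloading calls below come from the public TensorStream API and are assumptions of this example; the only method documented on this page is derivative):

require 'tensor_stream'

ts = TensorStream

# Build a small graph: y = x^2 + 3x
x = ts.variable(2.0, name: 'x')
y = x * x + x * 3.0

# Symbolically derive dy/dx as a new tensor in the same graph.
grad = TensorStream::MathGradients.derivative(y, x)

sess = ts.session
sess.run(ts.global_variables_initializer)
sess.run(grad) # expected: 7.0, i.e. 2x + 3 evaluated at x = 2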

Class Method Summary

Methods included from OpHelper

cons, dtype_eval, fp_type?, i_cons, i_op, op, shape_eval, val_to_dtype

Class Method Details

._ds(tensor) ⇒ Object



# File 'lib/tensor_stream/math_gradients.rb', line 133

def self._ds(tensor)
  return tensor unless tensor.is_a?(Operation)

  case tensor.operation
  when :reduce_sum
    tensor.items[0]
  else
    tensor
  end
end
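
In effect _ds unwraps a :reduce_sum node so the gradient rules differentiate with respect to the reduction's input; any other tensor is returned unchanged. A rough sketch (constant and reduce_sum are public TensorStream helpers, used here as assumptions):

x      = TensorStream.constant([[1.0, 2.0], [3.0, 4.0]])
summed = TensorStream.reduce_sum(x) # an Operation whose operation is :reduce_sum

TensorStream::MathGradients._ds(summed) # => the input of the reduce_sum (here x)
TensorStream::MathGradients._ds(x)      # => x, returned as-is (not an Operation)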

._grad_with_broadcast(tensor, wrt_dx, func, options) ⇒ Object



# File 'lib/tensor_stream/math_gradients.rb', line 144

def self._grad_with_broadcast(tensor, wrt_dx, func, options)
  grad = derivative(tensor.items[0], wrt_dx, options)
  grad2 = derivative(tensor.items[1], wrt_dx, options)
  elements1 = i_op(:reduce_prod, i_op(:shape, tensor.items[0]), data_type: :float32)
  elements2 = i_op(:reduce_prod, i_op(:shape, tensor.items[1]), data_type: :float32)
  multiplier = elements1 / elements2
  _reduce_when_necessary(func.call(grad, grad2 * multiplier), wrt_dx)
end
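
The element-count ratio compensates for broadcasting: when the second operand has fewer elements than the first, its gradient contribution is scaled up before the combining function (grad_sum for :add, grad_sub for :sub) is applied, and _reduce_when_necessary then collapses any extra dimensions. A plain-Ruby illustration of the scaling, not library code:

elements1  = 2 * 3                      # reduce_prod(shape(items[0])) for shape [2, 3]
elements2  = 3                          # reduce_prod(shape(items[1])) for shape [3]
multiplier = elements1.to_f / elements2 # => 2.0, the factor applied to grad2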

._include?(arr, obj) ⇒ Boolean

Returns:

  • (Boolean)


# File 'lib/tensor_stream/math_gradients.rb', line 153

def self._include?(arr, obj)
  arr.each { |a| return true if a.equal?(obj) }
  false
end
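
Membership is tested with equal? (object identity), not ==, so only the exact tensor objects passed in (for example via :stop_gradients) are matched. Illustration (constant is the public TensorStream helper, used here as an assumption):

a = TensorStream.constant(1.0)
b = TensorStream.constant(1.0)

TensorStream::MathGradients._include?([a], a) # => true  (same object)
TensorStream::MathGradients._include?([a], b) # => false (same value, different object)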

._reduce_when_necessary(tensor, wrt_dx) ⇒ Object



# File 'lib/tensor_stream/math_gradients.rb', line 158

def self._reduce_when_necessary(tensor, wrt_dx)
  rank = op(:rank, tensor)
  dx_rank = op(:rank, wrt_dx)
  reduced = op(:reduce_sum, tensor, nil, axis: 0)
  op(:cond, ->{ reduced }, tensor, pred: rank > dx_rank)
end
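
The rank comparison is built symbolically out of rank and cond nodes, so the decision is made when the session runs rather than when the gradient graph is constructed. An illustrative plain-Ruby sketch of the shapes involved, not library code:

# A gradient of shape [2, 3] reduced against a wrt_dx of shape [3]:
# rank 2 > rank 1, so the cond node selects reduce_sum(tensor, axis: 0),
# collapsing the broadcast dimension down to shape [3].
grad_shape = [2, 3]
wrt_shape  = [3]
grad_shape.length > wrt_shape.length # => true, the reduced branch is taken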

.derivative(tensor, wrt_dx, options = {}) ⇒ Object



# File 'lib/tensor_stream/math_gradients.rb', line 6

def self.derivative(tensor, wrt_dx, options = {})
  gradient_program_name = "_grad_#{tensor.name}_#{wrt_dx.name}"
  return options[:graph].get_node(gradient_program_name) if options[:graph] && options[:graph].node_added?(gradient_program_name)

  constant_options = { dtype: options[:dtype] }
  constant_options_1 = { dtype: options[:dtype] || tensor.data_type }

  return i_op(:ones_like, wrt_dx, constant_options_1) if tensor.equal?(wrt_dx)
  return i_cons(0, constant_options) if options[:stop_gradients] && _include?(options[:stop_gradients], tensor)

  if tensor.is_a?(Operation)
    grad = derivative(tensor.items[0], wrt_dx, options) if tensor.items[0]
    grad2 = derivative(tensor.items[1], wrt_dx, options) if tensor.items[1]

    case tensor.operation
    when :max
      x_mask = i_op(:where, i_op(:ones_like, tensor.items[0]), i_op(:zeros_like, tensor.items[1]), pred: tensor.items[0] > tensor.items[1])
      y_mask = i_op(:where, i_op(:zeros_like, tensor.items[0]), i_op(:ones_like, tensor.items[1]), pred: tensor.items[0] < tensor.items[1])
      x_mask * grad + y_mask * grad2
    when :where
      x_mask = i_op(:where, i_op(:ones_like, tensor.items[0]), i_op(:zeros_like, tensor.items[1]), pred: tensor.options[:pred])
      y_mask = i_op(:where, i_op(:zeros_like, tensor.items[0]), i_op(:ones_like, tensor.items[1]), pred: tensor.options[:pred])
      x_mask * grad + y_mask * grad2
    when :cond
      i_op(:cond, grad, grad2, pred: tensor.options[:pred])
    when :identity, :print, :pad
      grad
    when :negate
      i_cons(-1, constant_options_1) * grad
    when :abs
      grad * i_op(:sign, _ds(tensor.items[0]))
    when :square
      i_cons(2, constant_options_1) * _ds(tensor.items[0]) * grad
    when :exp
      i_op(:exp, tensor.items[0]) * grad
    when :log
      (i_cons(1, constant_options_1) / _ds(tensor.items[0])) * grad
    when :tanh
      (i_cons(1, constant_options_1) - (i_op(:tanh, _ds(tensor.items[0]))**2)) * grad
    when :tan
      (i_cons(1, constant_options_1) / (i_op(:cos, _ds(tensor.items[0]))**2)) * grad
    when :sin
      i_op(:cos, tensor.items[0]) * grad
    when :sqrt
      i_cons(1, constant_options_1) / (i_cons(2, constant_options_1) * i_op(:sqrt, _ds(tensor.items[0]))) * grad
    when :cos
      -i_op(:sin, tensor.items[0]) * grad
    when :add
      _grad_with_broadcast(tensor, wrt_dx, ->(a, b) { i_op(:add, a, b, name: 'grad_sum') }, options)
    when :sub
      _grad_with_broadcast(tensor, wrt_dx, ->(a, b) { i_op(:sub, a, b, name: 'grad_sub') }, options)
    when :pow
      gx = _ds(tensor.items[1]) * (_ds(tensor.items[0])**(_ds(tensor.items[1]) - 1)) * grad

      log_x = i_op(:where, i_op(:log, tensor.items[0], nil, name: 'log_pow_grad'), i_op(:zeros_like, tensor.items[0]), pred: tensor.items[0] > 0)
      gy = _ds(tensor.items[0])**_ds(tensor.items[1]) * log_x * grad2

      gx + gy
    when :div
      # apply the quotient rule
      gx = i_op(:div, grad, _ds(tensor.items[1]))
      gy = grad2 * i_op(:div, i_op(:div, -_ds(tensor.items[0]), _ds(tensor.items[1])), _ds(tensor.items[1]))

      _reduce_when_necessary(gx + gy, wrt_dx)
    when :mul
      # apply the product rule
      _reduce_when_necessary(grad * _ds(tensor.items[1]) + _ds(tensor.items[0]) * grad2, wrt_dx)
    when :reduce_mean
      input_size = i_op(:reduce_prod, i_op(:shape, tensor.items[0]))
      output_size = i_op(:reduce_prod, i_op(:shape, tensor))
      factor = input_size / output_size

      (grad / i_op(:cast, factor, data_type: grad.dtype))
    when :reduce_sum
      grad
    when :stop_gradient
      return i_cons(0, constant_options)
    when :matmul
      derivative_a = derivative(tensor.items[0], wrt_dx)
      derivative_b = derivative(tensor.items[1], wrt_dx)

      s0 =  i_op(:shape, tensor.items[0])
      s1 =  i_op(:shape, tensor.items[1])

      identity_0 = i_op(:ones, [s0[0], s1[1]], nil, data_type: tensor.items[0].data_type)
      identity_1 = i_op(:ones, [s0[0], s1[1]], nil, data_type: tensor.items[1].data_type)

      matmul_da = i_op(:matmul, identity_0, tensor.items[1], transpose_b: true,
                                                             pad_zeros: true,
                                                             name:        'matrix_dx')
      matmul_db = i_op(:matmul, tensor.items[0], identity_1, transpose_a: true,
                                                             pad_zeros: true,
                                                             name:        'matrix_dy')

      zero_vect = i_op(:zeros_like, wrt_dx, nil, name: 'zero_vect')

      # matmul_db = op(:transpose, matmul_db, nil).first

      # begin_a = op(:zeros, op(:rank, matmul_db), nil, data_type: :int32, name: 'begin_a')
      # matmul_b_shape = op(:shape, matmul_db)
      # end_a = [matmul_b_shape[0], 1]

      matmul_da = i_op(:cond, matmul_da[0], matmul_da, pred: op(:rank, derivative_a) > 0)

      # matmul_da = op(:cond, matmul_da[0], matmul_da, pred: op(:rank, derivative_a) > 0)
      norm_a = i_op(:mul, derivative_a, matmul_da, name: 'grad_a_norm_mul_da')
      norm_b = i_op(:mul, derivative_b, matmul_db, name: 'grad_b_norm_mul_db')

      # norm_a = i_op(:cond, norm_a[0], norm_a, pred: i_op(:rank, matmul_da) > i_op(:rank, derivative_a))
      # norm_b = i_op(:cond, norm_b[0], norm_b, pred: i_op(:rank, matmul_db) > i_op(:rank, derivative_b))

      i_op(:cond, norm_a, zero_vect, pred: i_op(:reduce_sum, norm_a) != 0) + i_op(:cond, norm_b, zero_vect, pred: i_op(:reduce_sum, norm_b) != 0)
      #  m.breakpoint! { |t, a, b, v|  binding.pry }
    else
      raise "no derivative implementation found for op #{tensor.operation}"
    end
  elsif tensor.is_a?(TensorStream::Variable)
    i_cons(0, constant_options)
  elsif tensor.is_a?(TensorStream::Placeholder)
    i_cons(0, constant_options)
  else
    i_cons(0, constant_options)
  end.tap do |ops|
    options[:graph].add_node!(gradient_program_name, ops) if options[:graph]
  end
end
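
The options hash controls caching and gradient blocking: :graph memoizes each gradient sub-program under a name derived from the two tensor names, :stop_gradients short-circuits any listed tensor to a zero constant, and :dtype overrides the data type of the generated constants. A hedged sketch of :stop_gradients (variable creation and operator overloading come from the public TensorStream API and are assumptions; names and values are illustrative):

a = TensorStream.variable(2.0, name: 'a')
b = a * a          # intermediate expression to be treated as a constant
c = b * a

# Gradients are not propagated through b, so the result reflects only the
# direct dependence of c on a through the outer multiplication (i.e. b itself).
g = TensorStream::MathGradients.derivative(c, a, stop_gradients: [b])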