Class: DNN::Param

Inherits: Object
Defined in:
lib/dnn/core/param.rb

Instance Attribute Summary

Instance Method Summary

Constructor Details

#initialize(data = nil, grad = nil) ⇒ Param

Returns a new instance of Param.



# File 'lib/dnn/core/param.rb', line 7

# Creates a parameter wrapping +data+ and its gradient +grad+.
# Every freshly built parameter is marked trainable.
#
# @param data [Object, nil] the parameter's value (nil by default)
# @param grad [Object, nil] the parameter's gradient (nil by default)
def initialize(data = nil, grad = nil)
  @data, @grad = data, grad
  @trainable = true
end

Instance Attribute Details

#data ⇒ Object

Returns the value of attribute data.



# File 'lib/dnn/core/param.rb', line 4

# Reader for the value held by this parameter.
#
# @return [Object, nil] the stored data
def data; @data; end

#grad ⇒ Object

Returns the value of attribute grad.



# File 'lib/dnn/core/param.rb', line 5

# Reader for the gradient currently stored on this parameter.
#
# @return [Object, nil] the stored gradient
def grad; @grad; end

#trainable ⇒ Object

Returns the value of attribute trainable.



# File 'lib/dnn/core/param.rb', line 3

# Reader for the trainable flag. It is set to true by the constructor.
# When it is falsy, #backward resets the gradient instead of accumulating.
#
# @return [Object] the stored flag
def trainable; @trainable; end

Instance Method Details

#*(other) ⇒ Object



# File 'lib/dnn/core/param.rb', line 50

# Multiplies this parameter by +other+ through a Layers::Mul node.
#
# An operand that is neither a DNN::Tensor nor a DNN::Param is first
# wrapped with Tensor.convert. Tensors and params are passed through as-is.
#
# @param other [DNN::Tensor, DNN::Param, Object] right-hand operand
# @return [Object] the result of Layers::Mul.call(self, operand)
def *(other)
  operand = if other.is_a?(DNN::Tensor) || other.is_a?(DNN::Param)
              other
            else
              Tensor.convert(other)
            end
  Layers::Mul.call(self, operand)
end

#**(index) ⇒ Object



# File 'lib/dnn/core/param.rb', line 60

# Raises this parameter to the power +index+.
# It builds a Layers::Pow node configured with +index+ and applies it to self.
#
# @param index [Object] exponent handed to Layers::Pow.new
# @return [Object] the result of applying the Pow node to self
def **(index)
  pow_layer = Layers::Pow.new(index)
  pow_layer.call(self)
end

#+(other) ⇒ Object



# File 'lib/dnn/core/param.rb', line 40

# Adds +other+ to this parameter through a Layers::Add node.
#
# An operand that is neither a DNN::Tensor nor a DNN::Param is first
# wrapped with Tensor.convert. Tensors and params are passed through as-is.
#
# @param other [DNN::Tensor, DNN::Param, Object] right-hand operand
# @return [Object] the result of Layers::Add.call(self, operand)
def +(other)
  operand = if other.is_a?(DNN::Tensor) || other.is_a?(DNN::Param)
              other
            else
              Tensor.convert(other)
            end
  Layers::Add.call(self, operand)
end

#+@ ⇒ Object



# File 'lib/dnn/core/param.rb', line 32

# Unary plus: the parameter is returned unchanged.
#
# @return [self]
def +@; self; end

#-(other) ⇒ Object



# File 'lib/dnn/core/param.rb', line 45

# Subtracts +other+ from this parameter through a Layers::Sub node.
#
# An operand that is neither a DNN::Tensor nor a DNN::Param is first
# wrapped with Tensor.convert. Tensors and params are passed through as-is.
#
# @param other [DNN::Tensor, DNN::Param, Object] right-hand operand
# @return [Object] the result of Layers::Sub.call(self, operand)
def -(other)
  operand = if other.is_a?(DNN::Tensor) || other.is_a?(DNN::Param)
              other
            else
              Tensor.convert(other)
            end
  Layers::Sub.call(self, operand)
end

#-@ ⇒ Object



# File 'lib/dnn/core/param.rb', line 36

# Unary minus: negates this parameter through a Layers::Neg node.
#
# The layer is referenced as Layers::Neg, the same namespace every other
# operator in this class uses (Layers::Add, Layers::Sub, Layers::Mul, ...).
# A bare +Neg+ is not looked up inside Layers, so it would raise NameError
# unless some other scope happened to define a Neg constant.
# NOTE(review): confirm no DNN::Neg alias was relied upon.
#
# @return [Object] the result of Layers::Neg.call(self)
def -@
  Layers::Neg.(self)
end

#/(other) ⇒ Object



# File 'lib/dnn/core/param.rb', line 55

# Divides this parameter by +other+ through a Layers::Div node.
#
# An operand that is neither a DNN::Tensor nor a DNN::Param is first
# wrapped with Tensor.convert. Tensors and params are passed through as-is.
#
# @param other [DNN::Tensor, DNN::Param, Object] right-hand operand
# @return [Object] the result of Layers::Div.call(self, operand)
def /(other)
  operand = if other.is_a?(DNN::Tensor) || other.is_a?(DNN::Param)
              other
            else
              Tensor.convert(other)
            end
  Layers::Div.call(self, operand)
end

#backward(grad) ⇒ Object



# File 'lib/dnn/core/param.rb', line 13

# Accumulates an incoming gradient into this parameter's stored gradient.
#
# Untrainable parameter: the stored gradient is reset to Xumo::SFloat[0]
# and +grad+ is ignored.
#
# Trainable parameter: the stored gradient starts at Xumo::SFloat[0] if
# unset, then +grad+ is added to it:
# * when +grad+ has exactly the parameter's shape, it is added directly;
# * when +grad+ has extra leading axes (e.g. a batch axis) and its trailing
#   axes match the parameter's shape, those leading axes are summed away first.
#
# @param grad [Object] incoming gradient. NOTE(review): assumed to be a Xumo
#   NArray responding to #shape, #sum(*axes) and #+; confirm with callers.
# @return [Object] the updated stored gradient
# @raise [DNNError] if +grad+'s shape cannot be reduced to the parameter's shape
def backward(grad)
  return @grad = Xumo::SFloat[0] unless @trainable

  @grad ||= Xumo::SFloat[0]
  data_shape = @data.shape
  grad_shape = grad.shape
  # Number of extra leading axes the incoming gradient carries.
  lead = grad_shape.size - data_shape.size
  if grad_shape == data_shape
    @grad += grad
  elsif lead.positive? && grad_shape[lead..-1] == data_shape
    # Collapse every leading axis; with one leading axis this is sum(0).
    @grad += grad.sum(*(0...lead))
  else
    raise DNNError, "Shape mismatch: param shape #{data_shape.inspect}, grad shape #{grad_shape.inspect}."
  end
end

#shape ⇒ Object



# File 'lib/dnn/core/param.rb', line 28

# Shape of the wrapped data, delegated to the data object's own #shape.
#
# @return [Object] whatever the stored data's #shape returns
def shape; @data.shape; end