Class: DNN::Param

Inherits:
Object
Defined in:
lib/dnn/core/param.rb

Instance Attribute Summary

Instance Method Summary

Constructor Details

#initialize(data = nil, grad = nil) ⇒ Param

Returns a new instance of Param.



7
8
9
10
11
# File 'lib/dnn/core/param.rb', line 7

# Creates a parameter wrapping +data+ together with an optional gradient
# +grad+. Every newly created parameter starts out trainable.
#
# @param data [Object, nil] the parameter's value (nil when omitted)
# @param grad [Object, nil] the parameter's gradient (nil when omitted)
def initialize(data = nil, grad = nil)
  @data, @grad = data, grad
  @trainable = true
end

Instance Attribute Details

#data ⇒ Object

Returns the value of attribute data.



4
5
6
# File 'lib/dnn/core/param.rb', line 4

# Reader for the parameter's value; returns +@data+ itself (no copy).
def data; @data; end

#grad ⇒ Object

Returns the value of attribute grad.



5
6
7
# File 'lib/dnn/core/param.rb', line 5

# Reader for the parameter's gradient; returns +@grad+ itself (no copy).
def grad; @grad; end

#trainable ⇒ Object

Returns the value of attribute trainable.



3
4
5
# File 'lib/dnn/core/param.rb', line 3

# Reader for the trainable flag; returns +@trainable+ as stored.
def trainable; @trainable; end

Instance Method Details

#backward(grad) ⇒ Object



13
14
15
16
17
18
19
20
21
22
23
24
25
26
# File 'lib/dnn/core/param.rb', line 13

# Accumulates an incoming gradient into this parameter.
#
# For a trainable parameter, +grad+ is added to +@grad+, which is first
# initialised to +Xumo::SFloat[0]+ if it is still nil. +grad+ may either have
# exactly the shape of +@data+, or one extra leading axis (e.g. a batch
# axis). In the latter case it is summed over axis 0 before being added.
# For a non-trainable parameter, +@grad+ is reset to +Xumo::SFloat[0]+.
#
# @param grad [Object] gradient array; presumably a Xumo (Numo/Cumo)
#   NArray responding to +shape+ and +sum+ — verify against callers
# @return the new value of +@grad+
# @raise [DNN_Error] if +grad+'s shape matches neither +@data.shape+ nor
#   +@data.shape+ with one extra leading axis
def backward(grad)
  # Non-trainable parameters discard the incoming gradient.
  return @grad = Xumo::SFloat[0] unless @trainable

  @grad ||= Xumo::SFloat[0]
  if @data.shape == grad.shape
    @grad += grad
  elsif @data.shape == grad.shape[1..-1]
    # Reduce the extra leading axis so the contribution matches @data's shape.
    @grad += grad.sum(0)
  else
    raise DNN_Error, "Shape mismatch: data shape #{@data.shape} vs grad shape #{grad.shape}."
  end
end