Class: DNN::Losses::Loss
- Inherits: Object
- Defined in:
- lib/dnn/core/losses.rb
Class Method Summary
collapse
Instance Method Summary
collapse
Class Method Details
.from_hash(hash) ⇒ Object
5
6
7
8
9
10
11
12
|
# File 'lib/dnn/core/losses.rb', line 5
# Rebuilds a loss object from a serialized hash, such as the one
# returned by #to_hash.
#
# The class named by hash[:class] is looked up with DNN.const_get and is
# validated *before* any object is allocated. An unexpected name (a
# module, a plain constant, or an unrelated class) therefore always
# raises DNN_Error, rather than leaking NoMethodError/TypeError out of
# #allocate. The instance is created with #allocate (bypassing #new) and
# then populated through #load_hash.
#
# @param hash [Hash, nil] serialized loss; its :class entry names the class.
# @return [Loss, nil] the restored loss, or nil when +hash+ is nil/false.
# @raise [DNN_Error] if the named constant is not a subclass of the receiver.
def self.from_hash(hash)
  return nil unless hash
  loss_class = DNN.const_get(hash[:class])
  # Module#<= returns nil for unrelated classes, so this also rejects them.
  unless loss_class.is_a?(Class) && loss_class <= self
    raise DNN_Error, "#{loss_class} is not an instance of #{self} class."
  end
  loss = loss_class.allocate
  loss.load_hash(hash)
  loss
end
|
Instance Method Details
#backward(y, t) ⇒ Object
27
28
29
|
# File 'lib/dnn/core/losses.rb', line 27
# Backward pass of the loss. This is abstract: concrete loss classes must
# override it (presumably to return the gradient with respect to +y+ —
# confirm against subclasses).
#
# @param y prediction
# @param t target
# @raise [NotImplementedError] always, in this base class.
def backward(y, t)
  raise NotImplementedError, "Class '#{self.class.name}' has to implement method 'backward'"
end
|
#forward(y, t) ⇒ Object
23
24
25
|
# File 'lib/dnn/core/losses.rb', line 23
# Forward pass of the loss: computes the loss value for +y+ against +t+.
# #loss calls it before adding any regularization terms. This is
# abstract: concrete loss classes must override it.
#
# @param y prediction
# @param t target
# @raise [NotImplementedError] always, in this base class.
def forward(y, t)
  raise NotImplementedError, "Class '#{self.class.name}' has to implement method 'forward'"
end
|
#load_hash(hash) ⇒ Object
53
54
55
|
# File 'lib/dnn/core/losses.rb', line 53
# Restores state from a serialized hash.
#
# The base implementation ignores +hash+ and only re-runs #initialize with
# no arguments. That completes setup of an object created via #allocate
# (as .from_hash does). Subclasses whose constructors take arguments
# presumably override this to read them from +hash+.
#
# @param hash [Hash] serialized loss (unused in the base class).
# @return whatever #initialize returns.
def load_hash(hash)
  initialize()
end
|
#loss(y, t, layers = nil) ⇒ Object
14
15
16
17
18
19
20
21
|
# File 'lib/dnn/core/losses.rb', line 14
# Computes the loss for prediction +y+ against target +t+.
#
# Calls #forward. When +layers+ is given, it adds the regularization term
# from #regularizers_forward. A result that is not already a Ruby number
# is reduced to a scalar with #sum.
#
# @param y prediction; must respond to #shape.
# @param t target; must have the same #shape as +y+.
# @param layers [Array, nil] layers whose regularizers contribute to the loss.
# @return the loss: the numeric result itself, or the #sum of an
#   array-like result.
# @raise [DNN_ShapeError] if y.shape != t.shape.
def loss(y, t, layers = nil)
  unless y.shape == t.shape
    raise DNN_ShapeError, "The shape of y does not match the t shape. y shape is #{y.shape}, but t shape is #{t.shape}."
  end
  loss_value = forward(y, t)
  loss_value += regularizers_forward(layers) if layers
  # Any Ruby number (Float, Integer, Rational) is already a scalar; only
  # array-like results need reducing. Checking for Float alone made an
  # Integer result crash with NoMethodError on #sum.
  loss_value.is_a?(Numeric) ? loss_value : loss_value.sum
end
|
#regularizers_backward(layers) ⇒ Object
41
42
43
44
45
|
# File 'lib/dnn/core/losses.rb', line 41
# Runs the backward step of every regularizer attached to +layers+.
#
# Layers that do not respond to #regularizers are skipped. Each remaining
# layer's regularizers have #backward called on them in order.
#
# @param layers [Array] layers to inspect.
# @return [Array] the layers that responded to #regularizers.
def regularizers_backward(layers)
  targets = layers.select { |candidate| candidate.respond_to?(:regularizers) }
  targets.each do |target|
    target.regularizers.each { |reg| reg.backward }
  end
  targets
end
|
#regularizers_forward(layers) ⇒ Object
31
32
33
34
35
36
37
38
39
|
# File 'lib/dnn/core/losses.rb', line 31
# Computes the combined regularization term for +layers+.
#
# Gathers the (flattened) regularizers of every layer that responds to
# #regularizers. It then threads a running value, starting at 0, through
# each regularizer's #forward in order. Each regularizer receives the
# previous result and returns the next one; presumably each adds its own
# penalty (confirm against regularizer implementations).
#
# @param layers [Array] layers to inspect.
# @return the final accumulated value (0 when there are no regularizers).
def regularizers_forward(layers)
  owners = layers.select { |candidate| candidate.respond_to?(:regularizers) }
  owners.map(&:regularizers).flatten.reduce(0) { |acc, reg| reg.forward(acc) }
end
|
#to_hash(merge_hash = nil) ⇒ Object
47
48
49
50
51
|
# File 'lib/dnn/core/losses.rb', line 47
# Serializes this loss to a hash that records its class name.
#
# @param merge_hash [Hash, nil] extra entries to include; on a key
#   conflict (including :class) its values win.
# @return [Hash] a new hash holding :class plus any entries of +merge_hash+.
def to_hash(merge_hash = nil)
  base = { class: self.class.name }
  merge_hash ? base.merge(merge_hash) : base
end
|