Class: DNN::Losses::Loss
- Inherits: Object
- Defined in: lib/dnn/core/losses.rb
Class Method Summary
Instance Method Summary
Class Method Details
.call(y, t, *args) ⇒ Object
5
6
7
|
# File 'lib/dnn/core/losses.rb', line 5
# One-shot helper: construct a loss with +args+ and immediately apply it.
#
# @param y first operand handed to the instance's #call
# @param t second operand handed to the instance's #call
# @param args constructor arguments forwarded unchanged to +new+
# @return whatever the freshly built instance's #call returns
def self.call(y, t, *args)
  instance = new(*args)
  instance.call(y, t)
end
|
.from_hash(hash) ⇒ Object
9
10
11
12
13
14
15
16
|
# File 'lib/dnn/core/losses.rb', line 9
# Rebuild a loss object from a serialized hash (presumably one produced by
# #to_hash, which stores the class name under +:class+).
#
# The class is resolved with +DNN.const_get(hash[:class])+ and must be +self+
# or a subclass of it. The instance is created with +allocate+ (so
# +initialize+ is not run here) and then populated via #load_hash.
#
# @param hash [Hash, nil] serialized loss; expected to contain a +:class+ key
# @return [Loss, nil] the restored loss, or nil when +hash+ is nil/false
# @raise [DNN_Error] if the resolved constant is not a subclass of +self+
# @raise [NameError] (from const_get) if the named constant does not exist
def self.from_hash(hash)
  return nil unless hash
  loss_class = DNN.const_get(hash[:class])
  # Validate before allocating: a constant naming a module or a non-class
  # value has no +allocate+, which would otherwise surface as NoMethodError
  # instead of the intended DNN_Error.
  unless loss_class.is_a?(Class) && loss_class <= self
    raise DNN_Error, "#{loss_class} is not an instance of #{self} class."
  end
  loss = loss_class.allocate
  loss.load_hash(hash)
  loss
end
|
Instance Method Details
#call(y, t) ⇒ Object
18
19
20
|
# File 'lib/dnn/core/losses.rb', line 18
# Evaluate the loss for +y+ against +t+.
#
# This is a thin entry point; the actual computation lives in #forward,
# which subclasses are expected to override.
#
# @param y prediction operand
# @param t target operand
# @return the result of #forward(y, t)
def call(y, t)
  forward(y, t)
end
|
#clean ⇒ Object
60
61
62
63
64
65
66
|
# File 'lib/dnn/core/losses.rb', line 60
# Discard all runtime state, keeping only what survives a serialization
# round-trip.
#
# A snapshot is taken with #to_hash, every instance variable is set to nil,
# and the object is then restored from the snapshot with #load_hash.
#
# @return the value returned by #load_hash
def clean
  snapshot = to_hash
  instance_variables.each { |name| instance_variable_set(name, nil) }
  load_hash(snapshot)
end
|
#forward(y, t) ⇒ Object
31
32
33
|
# File 'lib/dnn/core/losses.rb', line 31
# Abstract hook invoked by #call; subclasses override it to compute the
# loss for +y+ against +t+.
#
# @param y prediction operand
# @param t target operand
# @raise [NotImplementedError] always, in this base implementation
def forward(y, t)
  raise NotImplementedError, "Class '#{self.class.name}' has not implemented method 'forward'"
end
|
#load_hash(hash) ⇒ Object
56
57
58
|
# File 'lib/dnn/core/losses.rb', line 56
# Restore state from a serialized hash.
#
# The base implementation ignores +hash+ and just calls +initialize+ with no
# arguments. Subclasses presumably override this to read their own keys;
# verify against the concrete losses.
#
# @param hash [Hash] serialized form (unused in this base implementation)
# @return the return value of +initialize+
def load_hash(hash)
  initialize
end
|
#loss(y, t, layers = nil) ⇒ Object
22
23
24
25
26
27
28
29
|
# File 'lib/dnn/core/losses.rb', line 22
# Compute the loss, optionally adding regularization terms from +layers+.
#
# @param y prediction; must respond to +shape+
# @param t target; its +shape+ must equal +y.shape+
# @param layers [Array, nil] layers whose regularizers are applied through
#   #regularizers_forward; skipped when nil/false
# @return the (possibly regularized) loss value
# @raise [DNN_ShapeError] if +y.shape+ and +t.shape+ differ
def loss(y, t, layers = nil)
  raise DNN_ShapeError, "The shape of y does not match the t shape. y shape is #{y.shape}, but t shape is #{t.shape}." unless y.shape == t.shape

  value = call(y, t)
  layers ? regularizers_forward(value, layers) : value
end
|
#regularizers_backward(layers) ⇒ Object
44
45
46
47
48
|
# File 'lib/dnn/core/losses.rb', line 44
# Run the backward pass of every regularizer attached to +layers+.
#
# Only layers that respond to +regularizers+ are considered; each of their
# regularizers receives +backward+ in order.
#
# @param layers [Array] candidate layers
# @return [Array] the subset of +layers+ that respond to +regularizers+
def regularizers_backward(layers)
  regularized = layers.select { |layer| layer.respond_to?(:regularizers) }
  regularized.each do |layer|
    layer.regularizers.each { |reg| reg.backward }
  end
end
|
#regularizers_forward(loss, layers) ⇒ Object
35
36
37
38
39
40
41
42
|
# File 'lib/dnn/core/losses.rb', line 35
# Thread +loss+ through the forward pass of every regularizer in +layers+.
#
# Regularizers are gathered from layers that respond to +regularizers+
# (the collected lists are fully flattened), then applied in order, each one
# receiving the previous one's output.
#
# @param loss the base loss value
# @param layers [Array] candidate layers
# @return the loss after all regularizers; +loss+ itself when there are none
def regularizers_forward(loss, layers)
  layers
    .select { |layer| layer.respond_to?(:regularizers) }
    .map(&:regularizers)
    .flatten
    .reduce(loss) { |acc, regularizer| regularizer.forward(acc) }
end
|
#to_hash(merge_hash = nil) ⇒ Object
50
51
52
53
54
|
# File 'lib/dnn/core/losses.rb', line 50
# Serialize this loss to a hash.
#
# The result always carries the class name under +:class+. Any entries in
# +merge_hash+ are merged on top, so a +:class+ key there overrides it.
#
# @param merge_hash [Hash, nil] extra entries to include
# @return [Hash] a new hash
def to_hash(merge_hash = nil)
  base = { class: self.class.name }
  return base unless merge_hash

  base.merge(merge_hash)
end
|