Class: DNN::Losses::Loss

Inherits:
Object
  • Object
show all
Defined in:
lib/dnn/core/losses.rb

Class Method Summary collapse

Instance Method Summary collapse

Class Method Details

.from_hash(hash) ⇒ Object

Raises:

  • (DNN_Error) — if the class named by hash[:class] is not the receiver class or a subclass of it.

# File 'lib/dnn/core/losses.rb', line 5

# Rebuild a loss object from a hash produced by #to_hash.
#
# The class is resolved by name through DNN.const_get. Its instance is
# created with .allocate (no constructor call) and then restored via
# #load_hash.
#
# @param hash [Hash, nil] serialized loss; hash[:class] names the class to build.
# @return [Loss, nil] the restored loss, or nil when +hash+ is nil/false.
# @raise [DNN_Error] if the named constant is not this class or a subclass of it.
def self.from_hash(hash)
  return nil unless hash
  loss_class = DNN.const_get(hash[:class])
  # Validate before allocating. Allocating an arbitrary constant can fail with
  # an unrelated error (modules have no .allocate; some core classes have no
  # allocator), which would mask the intended DNN_Error.
  unless loss_class.is_a?(Class) && loss_class <= self
    raise DNN_Error, "#{loss_class} is not an instance of #{self} class."
  end
  loss = loss_class.allocate
  loss.load_hash(hash)
  loss
end

Instance Method Details

#backward(y, t) ⇒ Object

Raises:

  • (NotImplementedError)


27
28
29
# File 'lib/dnn/core/losses.rb', line 27

# Backward pass of the loss for prediction +y+ and target +t+.
#
# Abstract: the base class only raises, so concrete losses must override it.
# NOTE(review): presumably returns the gradient with respect to +y+ — confirm
# against the subclasses.
#
# @param y prediction
# @param t target
# @raise [NotImplementedError] always, in this base class.
def backward(y, t)
  raise NotImplementedError, "Class '#{self.class.name}' must implement method 'backward'"
end

#forward(y, t) ⇒ Object

Raises:

  • (NotImplementedError)


23
24
25
# File 'lib/dnn/core/losses.rb', line 23

# Compute the loss of prediction +y+ against target +t+.
#
# Abstract: the base class only raises, so concrete losses must override it.
# #loss calls this method and then sums the result unless it is already a Float.
#
# @param y prediction (its #shape is checked against +t+ by #loss)
# @param t target
# @raise [NotImplementedError] always, in this base class.
def forward(y, t)
  raise NotImplementedError, "Class '#{self.class.name}' must implement method 'forward'"
end

#load_hash(hash) ⇒ Object



53
54
55
# File 'lib/dnn/core/losses.rb', line 53

# Restore this object's state from a serialized hash.
#
# The base implementation ignores +hash+ and simply re-runs the
# zero-argument initializer on the already-allocated receiver.
# NOTE(review): subclasses with configuration presumably override this.
#
# @param hash [Hash] serialized form (unused by this implementation).
# @return whatever #initialize returns.
def load_hash(hash)
  __send__(:initialize)
end

#loss(y, t, layers = nil) ⇒ Object



14
15
16
17
18
19
20
21
# File 'lib/dnn/core/losses.rb', line 14

# Compute the loss value for prediction +y+ against target +t+.
#
# Runs #forward and, when +layers+ is given, adds the regularization term
# from #regularizers_forward. The result is then reduced with #sum unless
# it is already a Float.
#
# @param y prediction; must respond to #shape
# @param t target; its #shape must equal y's
# @param layers [Array, nil] layers whose regularizers contribute to the loss
# @return [Float, Object] the Float as-is, otherwise whatever #sum returns
# @raise [DNN_ShapeError] when y.shape and t.shape differ
def loss(y, t, layers = nil)
  if y.shape != t.shape
    raise DNN_ShapeError, "The shape of y does not match the t shape. y shape is #{y.shape}, but t shape is #{t.shape}."
  end

  total = forward(y, t)
  total += regularizers_forward(layers) if layers
  return total if total.is_a?(Float)

  total.sum
end

#regularizers_backward(layers) ⇒ Object



41
42
43
44
45
# File 'lib/dnn/core/losses.rb', line 41

# Run the backward step of every regularizer attached to +layers+.
#
# Layers that do not respond to #regularizers are skipped.
#
# @param layers [Array] layers to scan
# @return [Array] the layers that responded to #regularizers
def regularizers_backward(layers)
  regularized = layers.select { |lyr| lyr.respond_to?(:regularizers) }
  regularized.each { |lyr| lyr.regularizers.each(&:backward) }
end

#regularizers_forward(layers) ⇒ Object



31
32
33
34
35
36
37
38
39
# File 'lib/dnn/core/losses.rb', line 31

# Accumulate the regularization term contributed by +layers+.
#
# Collects every regularizer from layers that respond to #regularizers
# (deep-flattening nested arrays). It then threads a running value through
# each regularizer's #forward, starting from 0.
#
# @param layers [Array] layers to scan
# @return the accumulated value (0 when there are no regularizers)
def regularizers_forward(layers)
  layers
    .select { |lyr| lyr.respond_to?(:regularizers) }
    .map(&:regularizers)
    .flatten
    .inject(0) { |acc, reg| reg.forward(acc) }
end

#to_hash(merge_hash = nil) ⇒ Object



47
48
49
50
51
# File 'lib/dnn/core/losses.rb', line 47

# Serialize this object to a hash that .from_hash can read back.
#
# The result always starts with :class (this object's class name). Entries
# from +merge_hash+ are layered on top and may override :class. The
# argument itself is not modified.
#
# @param merge_hash [Hash, nil] extra entries to include
# @return [Hash]
def to_hash(merge_hash = nil)
  base = { class: self.class.name }
  return base unless merge_hash

  base.merge(merge_hash)
end