Class: DNN::Losses::Loss

Inherits:
Object
  • Object
show all
Defined in:
lib/dnn/core/losses.rb

Class Method Summary collapse

Instance Method Summary collapse

Class Method Details

.call(y, t, *args) ⇒ Object



5
6
7
# File 'lib/dnn/core/losses.rb', line 5

# Convenience one-shot evaluation: builds a loss object and applies it.
#
# @param y the prediction passed through to the instance's #call
# @param t the target passed through to the instance's #call
# @param args forwarded unchanged to the constructor
# @return whatever the new instance's #call returns
def self.call(y, t, *args)
  instance = new(*args)
  instance.call(y, t)
end

.from_hash(hash) ⇒ Object

Raises:



9
10
11
12
13
14
15
16
# File 'lib/dnn/core/losses.rb', line 9

# Rebuilds a loss object from a hash previously produced by #to_hash.
#
# The class named by +hash[:class]+ is resolved relative to the DNN module.
# It is validated as a subclass of the receiver BEFORE any object is
# allocated, so a constant that is not a Class (e.g. a module or plain value)
# is rejected with DNNError rather than failing on a missing #allocate.
#
# @param hash [Hash, nil] serialized loss; must contain a :class entry
# @return [Object, nil] the restored loss, or nil when +hash+ is nil/false
# @raise [DNNError] if the named constant is not the receiver or a subclass of it
def self.from_hash(hash)
  return nil unless hash

  loss_class = DNN.const_get(hash[:class])
  # `<=` on classes is truthy exactly when loss_class is self or a descendant,
  # which matches `loss_class.allocate.is_a?(self)` without allocating.
  unless loss_class.is_a?(Class) && loss_class <= self
    raise DNNError, "#{loss_class} is not an instance of #{self} class."
  end

  loss = loss_class.allocate
  loss.load_hash(hash)
  loss
end

Instance Method Details

#call(y, t) ⇒ Object



18
19
20
# File 'lib/dnn/core/losses.rb', line 18

# Evaluates the loss for prediction +y+ against target +t+.
#
# This is a thin entry point: all computation lives in #forward,
# which subclasses are expected to provide.
#
# @return the value returned by #forward
def call(y, t)
  forward(y, t)
end

#clean ⇒ Object



55
56
57
58
59
60
61
# File 'lib/dnn/core/losses.rb', line 55

# Resets this loss object to a freshly loaded state.
#
# Takes a snapshot of the configuration via #to_hash, sets every instance
# variable to nil, then rebuilds the object from that snapshot with #load_hash.
#
# @return whatever #load_hash returns
def clean
  snapshot = to_hash
  instance_variables.each { |name| instance_variable_set(name, nil) }
  load_hash(snapshot)
end

#forward(y, t) ⇒ Object

Raises:

  • (NotImplementedError)


32
33
34
# File 'lib/dnn/core/losses.rb', line 32

# Computes the raw loss between prediction +y+ and target +t+.
#
# This base implementation is abstract and always raises. Concrete loss
# classes must override it.
#
# @param y the model output
# @param t the target value
# @raise [NotImplementedError] always, in this base class
def forward(y, t)
  raise NotImplementedError, "Class '#{self.class.name}' has not implemented method 'forward'"
end

#load_hash(hash) ⇒ Object



51
52
53
# File 'lib/dnn/core/losses.rb', line 51

# Restores state from a serialized hash.
#
# The base class ignores +hash+ and simply re-runs #initialize with no
# arguments. NOTE(review): subclasses whose constructor takes options
# presumably override this to pass values from +hash+. Confirm per subclass.
#
# @param hash [Hash] serialized representation (unused here)
def load_hash(hash)
  initialize
end

#loss(y, t, layers: nil, loss_weight: nil) ⇒ Object



22
23
24
25
26
27
28
29
30
# File 'lib/dnn/core/losses.rb', line 22

# Computes the full training loss for prediction +y+ against target +t+.
#
# The base value comes from #call. It is multiplied by +loss_weight+ when one
# is given, and then passed through #regularizers_forward when +layers+ is given.
#
# @param y prediction; must respond to #shape
# @param t target; must have the same #shape as +y+
# @param layers [Array, nil] layers whose regularizers are applied
# @param loss_weight [Numeric, nil] optional multiplier applied before regularization
# @raise [DNNShapeError] if +y+ and +t+ shapes differ
def loss(y, t, layers: nil, loss_weight: nil)
  shapes_match = (y.shape == t.shape)
  unless shapes_match
    raise DNNShapeError,
          "The shape of y does not match the t shape. y shape is #{y.shape}, but t shape is #{t.shape}."
  end

  value = call(y, t)
  value *= loss_weight if loss_weight
  return value unless layers

  regularizers_forward(value, layers)
end

#regularizers_forward(loss, layers) ⇒ Object



36
37
38
39
40
41
42
43
# File 'lib/dnn/core/losses.rb', line 36

# Threads +loss+ through every regularizer found on +layers+, in order.
#
# Layers that do not respond to #regularizers are skipped. The collected
# regularizers are fully flattened, then each one's #forward receives the
# output of the previous one.
#
# @param loss the starting loss value
# @param layers [Array] candidate layers
# @return the value after the last regularizer (or +loss+ unchanged if none)
def regularizers_forward(loss, layers)
  layers.select { |layer| layer.respond_to?(:regularizers) }
        .map(&:regularizers)
        .flatten
        .reduce(loss) { |acc, regularizer| regularizer.forward(acc) }
end

#to_hash(merge_hash = nil) ⇒ Object



45
46
47
48
49
# File 'lib/dnn/core/losses.rb', line 45

# Serializes this loss into a hash suitable for #from_hash.
#
# The result always starts with +:class+ set to this object's class name.
# Entries from +merge_hash+, when given, are layered on top and may override it.
#
# @param merge_hash [Hash, nil] extra entries to include
# @return [Hash] a new hash
def to_hash(merge_hash = nil)
  base = { class: self.class.name }
  return base unless merge_hash

  base.merge(merge_hash)
end