Class: NeuralNetworkRb::Loss::CrossEntropyFetch

Inherits: Object
Defined in:
lib/neural_network_rb/loss/cross_entropy_fetch.rb

Instance Attribute Summary

Instance Method Summary

Constructor Details

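#initialize(next_layer, options = {}, &block) ⇒ CrossEntropyFetch

Returns a new instance of CrossEntropyFetch. `next_layer` is the layer that #predict and #train delegate to (nil for the last layer). `options[:every]` sets how often, in epochs, the optional block is called with the epoch index and the cross-entropy error.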



# File 'lib/neural_network_rb/loss/cross_entropy_fetch.rb', lines 7–12

# Builds a pass-through layer that reports the cross-entropy loss while
# training and otherwise delegates to the next layer in the chain.
#
# @param next_layer [Object, nil] layer that #predict and #train delegate to;
#   nil marks this as the last layer
# @param options [Hash] supports :every — invoke the callback every N epochs
#   (defaults to 1 when the key is absent or nil)
# @yield [epoch, error] optional callback receiving the epoch index and the
#   cross-entropy error
def initialize(next_layer, options = {}, &block)
  @next_layer = next_layer
  # Without a default, `@epoch % @every` in #train raises TypeError whenever
  # a callback is given but :every is not.
  @every = options[:every] || 1
  @block = block if block_given?
  @epoch = 0
end

Instance Attribute Details

#epoch ⇒ Object (readonly)

Returns the value of attribute epoch: the number of times #train has been called (0 right after construction).



# File 'lib/neural_network_rb/loss/cross_entropy_fetch.rb', lines 5–7

# Number of times #train has been invoked on this instance.
#
# @return [Integer] 0 right after construction
def epoch; @epoch; end

Instance Method Details

#predict(input) ⇒ Object

Passes `input` to the next layer's #predict, or returns it unchanged when there is no next layer.

# File 'lib/neural_network_rb/loss/cross_entropy_fetch.rb', lines 21–23

# Runs inference through the remainder of the layer chain.
#
# @param input [Object] value to predict on
# @return [Object] the next layer's prediction, or +input+ itself when this
#   is the last layer
def predict(input)
  return input if @next_layer.nil?

  @next_layer.predict(input)
end

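#train(input, target) ⇒ Object

Reports the cross-entropy of `input` against `target` to the callback every `every` epochs and advances the epoch counter. It then forwards to the next layer's #train, or returns 1 when there is no next layer.

# File 'lib/neural_network_rb/loss/cross_entropy_fetch.rb', lines 14–19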

# Reports the loss for this epoch and forwards training to the next layer.
#
# When a callback was supplied, it is invoked as block.call(epoch, error) on
# every epoch divisible by the configured :every interval. A nil or
# non-positive interval is treated as 1, so the callback fires every epoch
# instead of raising TypeError (nil) or ZeroDivisionError (0).
#
# @param input [Object] value passed to cross_entropy and the next layer
# @param target [Object] expected value passed to cross_entropy and the next layer
# @return [Object] the next layer's #train result, or 1 when this is the last layer
def train(input, target)
  interval = (@every.nil? || @every <= 0) ? 1 : @every
  # Only compute the loss when it will actually be handed to the callback.
  if @block && (@epoch % interval).zero?
    @block.call(@epoch, cross_entropy(input, target))
  end
  @epoch += 1
  @next_layer.nil? ? 1 : @next_layer.train(input, target)
end