Class: RubyZero::NN::Model

Inherits:
Object
Defined in:
lib/rubyzero/nn/model.rb

Overview

Base class for RubyZero neural-network models.

Direct Known Subclasses

Layers::Layer, Losses::Loss

Class Method Summary

Instance Method Summary

Constructor Details

#initialize ⇒ Model

Returns a new instance of Model.



4
5
# File 'lib/rubyzero/nn/model.rb', line 4

# Builds an empty model. The base class keeps no state of its own:
# #parameters and #__get_str__ discover a model's contents by scanning its
# instance variables, so subclasses only need to assign them.
def initialize
    # Intentionally empty.
end

Class Method Details

.load(path) ⇒ Object



69
70
71
72
73
# File 'lib/rubyzero/nn/model.rb', line 69

# Restores an object previously written with Model#save.
#
# SECURITY: Marshal.load can instantiate arbitrary classes from the payload;
# only load files that come from a trusted source.
#
# @param path [String] file produced by Model#save (read in binary mode)
# @return [Object] the unmarshalled object (presumably a Model — verify at call site)
def self.load(path)
    # File.open with a block returns the block's value and closes the file.
    File.open(path, "rb") { |io| Marshal.load(io) }
end

Instance Method Details

#__get_str__(num_indents) ⇒ Object



28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# File 'lib/rubyzero/nn/model.rb', line 28

# Renders this model and its sub-models as an indented tree, one line per
# model in the form "<ClassName> <N> params". Each nesting level adds two
# spaces. Only models held directly in instance variables are listed.
#
# @param num_indents [Integer] depth of this model in the tree
# @return [String] newline-terminated tree text
def __get_str__(num_indents)
    sub_models = instance_variables
        .map { |ivar| instance_variable_get(ivar) }
        .select { |value| value.is_a?(RubyZero::NN::Model) }
    pad = "  " * num_indents
    header = "#{pad}#{self.class.name} #{parameters.size} params\n"
    sub_models.each_with_object(header) do |sub, out|
        out << sub.__get_str__(num_indents + 1)
    end
end

#call(*args) ⇒ Object



9
10
11
# File 'lib/rubyzero/nn/model.rb', line 9

# Makes a model callable: model.call(x) (or model.(x)) runs #forward with
# the same positional arguments and returns whatever #forward returns.
def call(*args)
    forward(*args)
end

#eval ⇒ Object



53
54
55
56
57
58
59
60
# File 'lib/rubyzero/nn/model.rb', line 53

# Switches the model to inference mode by setting requires_grad = false on
# every parameter returned by #parameters.
#
# NOTE(review): #parameters itself sets requires_grad = true on tensors held
# directly in instance variables, so any later #parameters call re-enables
# gradients on them. Confirm whether that is intended.
#
# @return [self] so calls can be chained, e.g. model.eval.call(x)
def eval
    # Parameters exposes #map; it is used here only for its side effect.
    parameters.map do |param|
        param.requires_grad = false
    end
    self
end

#forward ⇒ Object



6
7
8
# File 'lib/rubyzero/nn/model.rb', line 6

# Abstract forward pass; subclasses must override it.
#
# Accepts (and ignores) any positional arguments. #call forwards its
# arguments here, so a subclass that forgot to override #forward gets this
# "not implemented" error rather than an ArgumentError about arity.
#
# @raise [RubyZero::Core::Exceptions::NoInplementError] always
def forward(*_args)
    raise RubyZero::Core::Exceptions::NoInplementError, "#{self.class}.forward() method is not implemented"
end

#inspect ⇒ Object



44
45
46
# File 'lib/rubyzero/nn/model.rb', line 44

# Human-readable summary: the model tree rendered by #__get_str__ starting
# at depth 0 (no indentation).
def inspect
    root_depth = 0
    __get_str__(root_depth)
end

#parameters ⇒ Object



12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# File 'lib/rubyzero/nn/model.rb', line 12

# Collects this model's trainable tensors by scanning its instance variables
# in definition order:
#   * a RubyZero::Core::Tensor is included (and has requires_grad set to true),
#   * a RubyZero::NN::Parameters group contributes all of its elements,
#   * a nested RubyZero::NN::Model contributes its own #parameters, recursively.
# Any other instance variable (numbers, strings, nil, ...) is ignored.
#
# NOTE(review): forcing requires_grad = true here undoes a prior #eval for
# directly-held tensors whenever #parameters is called again. Confirm intent.
#
# @return [Parameters] a new Parameters wrapping the collected tensors
def parameters
    params = []
    instance_variables.each do |key|
        obj = instance_variable_get(key)
        case obj
        when RubyZero::Core::Tensor
            obj.requires_grad = true
            params << obj
        when RubyZero::NN::Parameters
            params.concat(obj.elements)
        when RubyZero::NN::Model
            params.concat(obj.parameters.elements)
        end
    end
    Parameters.new(params)
end

#save(path) ⇒ Object



61
62
63
64
65
66
67
68
# File 'lib/rubyzero/nn/model.rb', line 61

# Serializes the whole model to +path+ with Marshal. Every parameter's
# gradient is cleared (grad = nil) first so it is not written out. Note
# that this also drops the gradients from the live model.
#
# @param path [String] destination file (opened in binary write mode)
# @return [IO] the closed file object, as returned by Marshal.dump(obj, io)
def save(path)
    parameters.map { |param| param.grad = nil }
    File.open(path, "wb") { |io| Marshal.dump(self, io) }
end

#train ⇒ Object



47
48
49
50
51
52
# File 'lib/rubyzero/nn/model.rb', line 47

# Switches the model to training mode by setting requires_grad = true on
# every parameter returned by #parameters.
#
# @return [self] so calls can be chained
def train
    parameters.map { |param| param.requires_grad = true }
    self
end