Class: Torch::NN::Module
- Inherits:
-
Object
show all
- Includes:
- Utils
- Defined in:
- lib/torch/nn/module.rb
Direct Known Subclasses
AvgPoolNd, BatchNorm, Bilinear, ConstantPadNd, ConvNd, CosineSimilarity, DropoutNd, Embedding, EmbeddingBag, Fold, GroupNorm, Hardshrink, Identity, LPPoolNd, LayerNorm, LeakyReLU, Linear, LocalResponseNorm, LogSigmoid, LogSoftmax, Loss, MaxPoolNd, MaxUnpoolNd, PReLU, PairwiseDistance, RNNBase, ReLU, ReflectionPadNd, ReplicationPadNd, Sequential, Sigmoid, Softmax, Softmax2d, Softmin, Softplus, Softshrink, Softsign, Tanh, Tanhshrink, Unfold
Instance Method Summary
collapse
Methods included from Utils
#_ntuple, #_pair, #_quadrupal, #_single, #_triple
Constructor Details
#initialize ⇒ Module
Returns a new instance of Module.
6
7
8
9
10
11
|
# File 'lib/torch/nn/module.rb', line 6
# Builds a module with its training flag switched on and three empty,
# independent registries (parameters, buffers, sub-modules), each a Hash.
def initialize
  @training = true
  @parameters, @buffers, @modules = {}, {}, {}
end
|
Dynamic Method Handling
This class handles dynamic methods through the method_missing method: a call whose name matches a parameter, buffer, or module key returns the matching object.
#method_missing(method, *args, &block) ⇒ Object
189
190
191
192
193
194
195
196
197
198
199
200
|
# File 'lib/torch/nn/module.rb', line 189
# Resolves otherwise-undefined method calls against the module's named
# state. The method name (as a String) is looked up in named_parameters,
# then named_buffers, then named_modules; the first hit is returned.
# Anything else falls through to the default method_missing behaviour.
#
# @param method [Symbol] name of the missing method
# @param args [Array] arguments of the original call (only used by super)
# @param block [Proc, nil] block of the original call (only used by super)
# @return [Object] the matching parameter, buffer, or module
def method_missing(method, *args, &block)
  key = method.to_s
  return named_parameters[key] if named_parameters.key?(key)
  return named_buffers[key] if named_buffers.key?(key)
  return named_modules[key] if named_modules.key?(key)

  super
end
|
Instance Method Details
#_apply(fn) ⇒ Object
33
34
35
36
37
38
39
|
# File 'lib/torch/nn/module.rb', line 33
# Passes +fn+ down to every child module's _apply and returns self.
#
# NOTE(review): as written, fn is only handed to the children; nothing in
# this method calls it on this module's own parameters or buffers. Confirm
# whether that is intentional.
#
# @param fn [#call] callable forwarded to each child
# @return [self]
def _apply(fn)
  children.each { |child| child._apply(fn) }
  self
end
|
#add_module(name, mod) ⇒ Object
28
29
30
31
|
# File 'lib/torch/nn/module.rb', line 28
# Stores +mod+ in the @modules hash under +name+.
#
# @param name [Object] key to register the module under
# @param mod [Object] the module to register
# @return [Object] +mod+ (the value of the assignment)
def add_module(name, mod)
@modules[name] = mod
end
|
#apply(fn) ⇒ Object
41
42
43
44
45
46
47
|
# File 'lib/torch/nn/module.rb', line 41
# Calls +fn+ on every module in the tree rooted here, children first and
# this module last, then returns self.
#
# @param fn [#call] callable that receives each module
# @return [self]
def apply(fn)
  children.each { |child| child.apply(fn) }
  fn.call(self)
  self
end
|
#buffers ⇒ Object
111
112
113
|
# File 'lib/torch/nn/module.rb', line 111
# Returns the values of named_buffers, without their names.
#
# @return [Array] the registered buffers
def buffers
named_buffers.values
end
|
#call(*input) ⇒ Object
82
83
84
|
# File 'lib/torch/nn/module.rb', line 82
# Forwards every positional argument to #forward and returns its result.
#
# @param input [Array] arguments passed straight through to #forward
# @return [Object] whatever #forward returns
def call(*input)
forward(*input)
end
|
#children ⇒ Object
119
120
121
|
# File 'lib/torch/nn/module.rb', line 119
# Returns the values of named_children, without their names.
#
# @return [Array] the child modules
def children
named_children.values
end
|
#cpu ⇒ Object
53
54
55
|
# File 'lib/torch/nn/module.rb', line 53
# Hands _apply a lambda that calls +cpu+ on its argument
# (presumably a tensor -- confirm against _apply's callers).
#
# @return [Object] whatever _apply returns
def cpu
_apply ->(t) { t.cpu }
end
|
#cuda(device: nil) ⇒ Object
49
50
51
|
# File 'lib/torch/nn/module.rb', line 49
# Hands _apply a lambda that calls +cuda(device)+ on its argument
# (presumably a tensor -- confirm against _apply's callers).
#
# @param device [Object, nil] passed positionally to +cuda+; nil by default
# @return [Object] whatever _apply returns
def cuda(device: nil)
_apply ->(t) { t.cuda(device) }
end
|
#double ⇒ Object
65
66
67
|
# File 'lib/torch/nn/module.rb', line 65
# Hands _apply a lambda that returns +t.double+ when +t.floating_point?+
# is truthy and returns +t+ unchanged otherwise.
#
# @return [Object] whatever _apply returns
def double
_apply ->(t) { t.floating_point? ? t.double : t }
end
|
#eval ⇒ Object
151
152
153
|
# File 'lib/torch/nn/module.rb', line 151
# Shorthand for +train(false)+.
#
# @return [Object] whatever #train returns
def eval
train(false)
end
|
#float ⇒ Object
61
62
63
|
# File 'lib/torch/nn/module.rb', line 61
# Hands _apply a lambda that returns +t.float+ when +t.floating_point?+
# is truthy and returns +t+ unchanged otherwise.
#
# @return [Object] whatever _apply returns
def float
_apply ->(t) { t.floating_point? ? t.float : t }
end
|
#forward ⇒ Object
13
14
15
|
# File 'lib/torch/nn/module.rb', line 13
# Placeholder for the module's computation; subclasses are expected to
# override it.
#
# Accepts (and ignores) any positional arguments so that a call routed
# through #call, which forwards its arguments here, reports the missing
# override with NotImplementedError rather than an unrelated ArgumentError
# about the argument count.
#
# @raise [NotImplementedError] always
def forward(*)
  raise NotImplementedError, "#{self.class.name} must implement #forward"
end
|
#half ⇒ Object
69
70
71
|
# File 'lib/torch/nn/module.rb', line 69
# Hands _apply a lambda that returns +t.half+ when +t.floating_point?+
# is truthy and returns +t+ unchanged otherwise.
#
# @return [Object] whatever _apply returns
def half
_apply ->(t) { t.floating_point? ? t.half : t }
end
|
#inspect ⇒ Object
175
176
177
178
179
180
181
182
183
184
185
186
187
|
# File 'lib/torch/nn/module.rb', line 175
# Builds a readable description of the module.
#
# A module without children renders as "Name()". Otherwise the output is
# "Name(" followed by one "(child_name): <child inspect>" line per child and
# a closing ")". Name is the last "::"-separated segment of the class name.
#
# @return [String]
def inspect
  label = self.class.name.split("::").last
  kids = named_children
  return "#{label}()" if kids.empty?

  str = String.new
  str << "#{label}(\n"
  # Walk the name => module pairs: iterating the bare list of children would
  # hand the block a single object, leaving the second block parameter nil.
  kids.each do |child_name, mod|
    str << " (#{child_name}): #{mod.inspect}\n"
  end
  str << ")"
end
|
#modules ⇒ Object
135
136
137
|
# File 'lib/torch/nn/module.rb', line 135
# Returns the values of named_modules, without their names.
#
# @return [Array] the modules listed by named_modules
def modules
named_modules.values
end
|
#named_buffers ⇒ Object
115
116
117
|
# File 'lib/torch/nn/module.rb', line 115
# Returns the @buffers hash, or a new empty hash when @buffers is nil/false
# (for example when it was never assigned).
#
# @return [Hash] buffers keyed by the name they were registered under
def named_buffers
@buffers || {}
end
|
#named_children ⇒ Object
123
124
125
126
127
128
129
130
131
132
133
|
# File 'lib/torch/nn/module.rb', line 123
# Collects this module's immediate children by name.
#
# Two sources are combined: instance variables whose value is_a?(Module)
# (keyed by the variable name without its leading "@"), followed by the
# entries of @modules, which replace any same-named instance-variable entry.
# `Module` is resolved lexically where the method is defined -- presumably
# this module class rather than Ruby's ::Module; confirm against the file.
#
# @modules is treated as empty when unset, matching the guard used by
# named_buffers, so an instance whose initializer never assigned it does not
# fail with NoMethodError on nil.
#
# @return [Hash] child modules keyed by name
def named_children
  found = {}
  instance_variables.each do |ivar|
    value = instance_variable_get(ivar)
    found[ivar[1..-1]] = value if value.is_a?(Module)
  end
  found.merge!(@modules || {})
end
|
#named_modules ⇒ Object
139
140
141
|
# File 'lib/torch/nn/module.rb', line 139
# Returns a hash that maps the empty string to this module, merged with
# named_children. A child stored under "" would replace the self entry,
# because the merged-in hash wins on key collisions.
#
# @return [Hash] this module plus the entries of named_children
def named_modules
{"" => self}.merge(named_children)
end
|
#named_parameters(prefix: "", recurse: true) ⇒ Object
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
|
# File 'lib/torch/nn/module.rb', line 94
# Collects parameters by dotted name.
#
# Sources, in order: (1) when +recurse+ is truthy, the parameters of every
# child from named_children; (2) this module's instance variables whose value
# is_a?(Parameter), keyed by the variable name without "@"; (3) the entries
# of @parameters. Later sources overwrite earlier ones on a name clash.
#
# @param prefix [String] text prepended to every name produced by this call
# @param recurse [Boolean] whether to descend into child modules
# @return [Hash] parameters keyed by prefixed name
def named_parameters(prefix: "", recurse: true)
  params = {}
  if recurse
    named_children.each do |name, mod|
      # Extend the incoming prefix rather than replacing it, so parameters
      # nested several levels deep keep their full path ("a.b.weight").
      params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
    end
  end
  instance_variables.each do |ivar|
    value = instance_variable_get(ivar)
    params["#{prefix}#{ivar[1..-1]}"] = value if value.is_a?(Parameter)
  end
  # Treat an unset registry as empty, mirroring the guard in named_buffers.
  (@parameters || {}).each do |name, param|
    params["#{prefix}#{name}"] = param
  end
  params
end
|
#parameters ⇒ Object
90
91
92
|
# File 'lib/torch/nn/module.rb', line 90
# Returns the values of named_parameters (called with its default
# arguments), without their names.
#
# @return [Array] the parameters
def parameters
named_parameters.values
end
|
#register_buffer(name, tensor) ⇒ Object
17
18
19
20
21
|
# File 'lib/torch/nn/module.rb', line 17
# Records +tensor+ as a buffer: it is stored in @buffers under +name+ and
# also exposed as the instance variable "@<name>".
#
# @param name [String, Symbol] buffer name; must form a valid instance
#   variable name once prefixed with "@"
# @param tensor [Object] the value to register
# @return [Object] +tensor+
def register_buffer(name, tensor)
  @buffers.store(name, tensor)
  instance_variable_set(:"@#{name}", tensor)
end
|
#register_parameter(name, param) ⇒ Object
23
24
25
26
|
# File 'lib/torch/nn/module.rb', line 23
# Stores +param+ in the @parameters hash under +name+.
#
# @param name [Object] key to register the parameter under
# @param param [Object] the parameter to register
# @return [Object] +param+ (the value of the assignment)
def register_parameter(name, param)
@parameters[name] = param
end
|
#requires_grad!(requires_grad: true) ⇒ Object
155
156
157
158
159
160
|
# File 'lib/torch/nn/module.rb', line 155
# Calls +requires_grad!+ with the given flag on every entry of #parameters,
# then returns self.
#
# @param requires_grad [Boolean] flag passed to each parameter
# @return [self]
def requires_grad!(requires_grad: true)
  parameters.each { |param| param.requires_grad!(requires_grad) }
  self
end
|
#respond_to?(method, include_private = false) ⇒ Boolean
202
203
204
205
|
# File 'lib/torch/nn/module.rb', line 202
# Reports whether the object answers +method+, counting the dynamically
# resolved names as well: true when the name (as a String) is a key of
# named_parameters, named_buffers, or named_modules; otherwise defers to the
# inherited respond_to?.
#
# @param method [Symbol, String] method name to check
# @param include_private [Boolean] passed through to the inherited check
# @return [Boolean]
def respond_to?(method, include_private = false)
  key = method.to_s
  return true if named_parameters.key?(key)
  return true if named_buffers.key?(key)
  return true if named_modules.key?(key)

  super
end
|
#share_memory ⇒ Object
171
172
173
|
# File 'lib/torch/nn/module.rb', line 171
# Hands _apply a lambda that calls +share_memory!+ on its argument
# (presumably a tensor -- confirm against _apply's callers).
#
# @return [Object] whatever _apply returns
def share_memory
_apply ->(t) { t.share_memory! }
end
|
#state_dict ⇒ Object
86
87
88
|
# File 'lib/torch/nn/module.rb', line 86
# Not available: always raises NotImplementedYet.
# NOTE(review): NotImplementedYet is not defined in the visible source;
# presumably it is a project-level exception class -- confirm.
#
# @raise [NotImplementedYet] always
def state_dict
raise NotImplementedYet
end
|
#to(device) ⇒ Object
74
75
76
77
78
79
80
|
# File 'lib/torch/nn/module.rb', line 74
# Hands _apply a lambda that calls +to(device)+ on its argument
# (presumably a tensor -- confirm against _apply's callers).
#
# @param device [Object] target passed to each +to+ call
# @return [Object] whatever _apply returns
def to(device)
  _apply ->(t) { t.to(device) }
end
|
#train(mode = true) ⇒ Object
143
144
145
146
147
148
149
|
# File 'lib/torch/nn/module.rb', line 143
# Sets @training to +mode+ on this module, does the same for every child
# via their own #train, and returns self.
#
# @param mode [Boolean] value stored in @training (true by default)
# @return [self]
def train(mode = true)
  @training = mode
  children.each { |child| child.train(mode) }
  self
end
|
#type(dst_type) ⇒ Object
57
58
59
|
# File 'lib/torch/nn/module.rb', line 57
# Hands _apply a lambda that calls +type(dst_type)+ on its argument
# (presumably a tensor -- confirm against _apply's callers).
#
# @param dst_type [Object] passed to each +type+ call
# @return [Object] whatever _apply returns
def type(dst_type)
_apply ->(t) { t.type(dst_type) }
end
|
#zero_grad ⇒ Object
162
163
164
165
166
167
168
169
|
# File 'lib/torch/nn/module.rb', line 162
# For every entry of #parameters that has a truthy +grad+, calls +detach!+
# and then +zero!+ on that grad. Parameters without a grad are skipped.
#
# @return [Array] the collection returned by #parameters
def zero_grad
  parameters.each do |param|
    next unless param.grad

    param.grad.detach!
    param.grad.zero!
  end
end
|