Class: Torch::NN::Module
  
  
  
  
  
    - Inherits:
- 
      Object
      
        
        show all
      
    
      - Includes:
- Utils
    - Defined in:
- lib/torch/nn/module.rb
 
  Direct Known Subclasses
  AdaptiveAvgPoolNd, AdaptiveMaxPoolNd, AvgPoolNd, BatchNorm, Bilinear, ConstantPadNd, ConvNd, CosineSimilarity, DropoutNd, Embedding, EmbeddingBag, Fold, GroupNorm, Hardshrink, Identity, LPPoolNd, LayerNorm, LeakyReLU, Linear, LocalResponseNorm, LogSigmoid, LogSoftmax, Loss, MaxPoolNd, MaxUnpoolNd, ModuleList, MultiheadAttention, PReLU, PairwiseDistance, RNNBase, ReLU, ReflectionPadNd, ReplicationPadNd, Sequential, Sigmoid, Softmax, Softmax2d, Softmin, Softplus, Softshrink, Softsign, Tanh, Tanhshrink, Transformer, TransformerDecoder, TransformerDecoderLayer, TransformerEncoder, TransformerEncoderLayer, Unfold, Upsample
 
  Instance Attribute Summary collapse
  
  
    
      Instance Method Summary
      collapse
    
    
      
        - 
  
    
      #_apply(fn)  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #add_module(name, mod)  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #apply(fn)  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #buffers  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #call(*input, **kwargs)  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #children  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #cpu  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #cuda  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #deep_dup  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #double  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #eval  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #float  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #forward  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #half  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #initialize  ⇒ Module 
    
    
  
  
  
    constructor
  
  
  
  
  
  
  
    
A new instance of Module. 
 
- 
  
    
      #inspect  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #load_state_dict(state_dict, strict: true)  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #method_missing(method, *args, &block)  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #modules  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #named_buffers  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #named_children  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #named_modules(memo: nil, prefix: "")  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #named_parameters(prefix: "", recurse: true)  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #parameters  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #register_buffer(name, tensor)  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #register_parameter(name, param)  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #requires_grad!(requires_grad: true)  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #respond_to?(method, include_private = false)  ⇒ Boolean 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #share_memory  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #state_dict(destination: nil, prefix: "")  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #to(device)  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #train(mode = true)  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #type(dst_type)  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
- 
  
    
      #zero_grad  ⇒ Object 
    
    
  
  
  
  
  
  
  
  
  
    
  
Methods included from Utils
  #_activation_fn, #_clones, #_ntuple, #_pair, #_quadrupal, #_single, #_triple
  Constructor Details
  
    
  
  
    #initialize  ⇒ Module 
  
  
  
  
    
Returns a new instance of Module.
   
 
  
  
    | 
8
9
10
11
12
13 | # File 'lib/torch/nn/module.rb', line 8
# Builds a module in training mode with empty registries.
#
# The assignment order is kept as-is because it determines the order in which
# +instance_variables+ later reports these variables.
def initialize
  @training = true     # mode flag exposed through #training
  @parameters = {}     # filled by #register_parameter
  @buffers = {}        # filled by #register_buffer
  @modules = {}        # filled by #add_module
end
 
  
 
  Dynamic Method Handling
  
    This class handles dynamic methods through the method_missing method
    
  
  
    
  
  
    #method_missing(method, *args, &block)  ⇒ Object 
  
  
  
  
    | 
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305 | # File 'lib/torch/nn/module.rb', line 288
# Resolves otherwise-unknown calls against this module's named parameters,
# named buffers, and named modules (checked in that order) and returns the
# matching entry.
#
# A call whose name ends in "=" and whose base name is a known module
# overwrites the same-named instance variable when one is already defined;
# when it is not defined, NotImplementedYet is raised. Every other call is
# passed on to +super+.
def method_missing(method, *args, &block)
  key = method.to_s

  params = named_parameters
  return params[key] if params.key?(key)

  bufs = named_buffers
  return bufs[key] if bufs.key?(key)

  mods = named_modules
  return mods[key] if mods.key?(key)

  # Only "name=" calls for an existing module are treated as setters.
  return super unless method.end_with?("=") && mods.key?(method[0..-2])

  ivar = "@#{method[0..-2]}"
  raise NotImplementedYet unless instance_variable_defined?(ivar)

  instance_variable_set(ivar, *args)
end
 
  
 
  
    Instance Attribute Details
    
      
      
      
  
  
    #training  ⇒ Object  
  
  
  
  
    
Returns the value of attribute training.
   
 
  
  
    | 
6
7
8 | # File 'lib/torch/nn/module.rb', line 6
# Reader for the @training flag (set in #initialize and by #train).
#
# @return the current value of @training
def training
  @training
end
 
    
   
  
    Instance Method Details
    
      
  
  
    #_apply(fn)  ⇒ Object 
  
  
  
  
    | 
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70 | # File 'lib/torch/nn/module.rb', line 35
# Applies +fn+ to every tensor owned by this module and its children.
#
# Order of work: children first (via their own +_apply+), then every
# Parameter held in an instance variable, then every Parameter stored in
# @parameters, then every non-nil buffer. A converted parameter is wrapped in
# a new Parameter carrying the original +requires_grad+ flag; when the
# original has a gradient, the gradient is converted too and attached to the
# new parameter.
#
# @param fn [#call] receives a tensor and returns its replacement
# @return [self]
def _apply(fn)
  children.each { |child| child._apply(fn) }

  # Keyed by object identity so a parameter reachable both through an instance
  # variable and through @parameters is converted once and stays one object.
  converted = {}.compare_by_identity
  convert = lambda do |param|
    converted[param] ||= begin
      applied = nil
      Torch.no_grad { applied = fn.call(param) }
      replacement = Parameter.new(applied, requires_grad: param.requires_grad)

      if param.grad
        grad_applied = nil
        Torch.no_grad { grad_applied = fn.call(param.grad) }
        replacement.grad = grad_applied.requires_grad!(param.grad.requires_grad)
      end

      replacement
    end
  end

  instance_variables.each do |ivar|
    value = instance_variable_get(ivar)
    instance_variable_set(ivar, convert.call(value)) if value.is_a?(Parameter)
  end

  # Parameters added with register_parameter live only in @parameters.
  # named_parameters reports them, so they must be converted here as well;
  # previously they were skipped and kept their old device/dtype.
  registered = @parameters || {}
  registered.each_key do |name|
    value = registered[name]
    registered[name] = convert.call(value) if value.is_a?(Parameter)
  end

  @buffers.each_key do |name|
    buf = @buffers[name]
    next if buf.nil?

    @buffers[name] = fn.call(buf)
    # Keep the mirror instance variable created by register_buffer in sync.
    instance_variable_set("@#{name}", @buffers[name])
  end

  self
end
 
    
      
  
  
    #add_module(name, mod)  ⇒ Object 
  
  
  
  
    | 
30
31
32
33 | # File 'lib/torch/nn/module.rb', line 30
# Stores +mod+ in the @modules registry under +name+.
#
# @param name key to register the module under
# @param mod the module to register
# @return the stored +mod+
def add_module(name, mod)
  @modules[name] = mod
end
 
    
      
  
  
    #apply(fn)  ⇒ Object 
  
  
  
  
    | 
72
73
74
75
76
77
78 | # File 'lib/torch/nn/module.rb', line 72
# Calls +fn+ on every child (recursively, through each child's own +apply+)
# and then on this module itself.
#
# @param fn [#call] invoked with each module
# @return [self]
def apply(fn)
  children.each { |child| child.apply(fn) }
  fn.call(self)
  self
end
 
    
      
  
  
    #buffers  ⇒ Object 
  
  
  
  
    | 
189
190
191 | # File 'lib/torch/nn/module.rb', line 189
# Lists the buffers of this module without their names.
#
# @return [Array] the values of #named_buffers
def buffers
  named_buffers.values
end
 
    
      
  
  
    #call(*input, **kwargs)  ⇒ Object 
  
  
  
  
    | 
114
115
116 | # File 'lib/torch/nn/module.rb', line 114
# Invokes #forward with the given positional and keyword arguments.
#
# @return whatever #forward returns
def call(*input, **kwargs)
  forward(*input, **kwargs)
end
 
    
      
  
  
    #children  ⇒ Object 
  
  
  
  
    | 
197
198
199 | # File 'lib/torch/nn/module.rb', line 197
# Lists the direct child modules without their names.
#
# @return [Array] the values of #named_children
def children
  named_children.values
end
 
    
      
  
  
    #cpu  ⇒ Object 
  
  
  
  
    | 
85
86
87 | # File 'lib/torch/nn/module.rb', line 85
# Runs +cpu+ on every tensor visited by #_apply.
#
# @return the result of #_apply
def cpu
  _apply(lambda { |tensor| tensor.cpu })
end
 
    
      
  
  
    #cuda  ⇒ Object 
  
  
  
  
  
    | 
81
82
83 | # File 'lib/torch/nn/module.rb', line 81
# Runs +cuda+ on every tensor visited by #_apply.
#
# @return the result of #_apply
def cuda
  _apply(lambda { |tensor| tensor.cuda })
end
 
    
      
  
  
    #deep_dup  ⇒ Object 
  
  
  
  
    | 
283
284
285
286 | # File 'lib/torch/nn/module.rb', line 283
# Duplicates this module by delegating to +dup_value+ with a fresh, empty
# memo hash.
#
# @return whatever +dup_value+ returns for +self+
def deep_dup
  dup_value(self, {})
end
 
    
      
  
  
    #double  ⇒ Object 
  
  
  
  
    | 
97
98
99 | # File 'lib/torch/nn/module.rb', line 97
# Runs +double+ on every floating-point tensor visited by #_apply; tensors
# for which +floating_point?+ is false are passed through unchanged.
#
# @return the result of #_apply
def double
  cast = lambda { |tensor| tensor.floating_point? ? tensor.double : tensor }
  _apply(cast)
end
 
    
      
  
  
    #eval  ⇒ Object 
  
  
  
  
    | 
243
244
245 | # File 'lib/torch/nn/module.rb', line 243
# Shorthand for +train(false)+.
#
# @return the result of #train
def eval
  train(false)
end
 
    
      
  
  
    #float  ⇒ Object 
  
  
  
  
    | 
93
94
95 | # File 'lib/torch/nn/module.rb', line 93
# Runs +float+ on every floating-point tensor visited by #_apply; tensors
# for which +floating_point?+ is false are passed through unchanged.
#
# @return the result of #_apply
def float
  cast = lambda { |tensor| tensor.floating_point? ? tensor.float : tensor }
  _apply(cast)
end
 
    
      
  
  
    #forward  ⇒ Object 
  
  
  
  
    | 
15
16
17 | # File 'lib/torch/nn/module.rb', line 15
# Base implementation: always raises. #call dispatches here, so a module is
# presumably expected to supply its own +forward+ (assumption, not shown here).
#
# @raise [NotImplementedError] always
def forward
  raise NotImplementedError
end
 
    
      
  
  
    #half  ⇒ Object 
  
  
  
  
    | 
101
102
103 | # File 'lib/torch/nn/module.rb', line 101
# Runs +half+ on every floating-point tensor visited by #_apply; tensors
# for which +floating_point?+ is false are passed through unchanged.
#
# @return the result of #_apply
def half
  cast = lambda { |tensor| tensor.floating_point? ? tensor.half : tensor }
  _apply(cast)
end
 
    
      
  
  
    #inspect  ⇒ Object 
  
  
  
  
    | 
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281 | # File 'lib/torch/nn/module.rb', line 267
def inspect
  name = self.class.name.split("::").last
  if named_children.empty?
    "#{name}(#{})"
  else
    str = String.new
    str << "#{name}(\n"
    named_children.each do |name, mod|
      mod_str = mod.inspect
      mod_str = mod_str.lines.join("  ")
      str << "  (#{name}): #{mod_str}\n"
    end
    str << ")"
  end
end | 
 
    
      
  
  
    #load_state_dict(state_dict, strict: true)  ⇒ Object 
  
  
  
  
    | 
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166 | # File 'lib/torch/nn/module.rb', line 129
# Loads +state_dict+ into this module and, recursively, into every non-nil
# child returned by +named_children+, extending the key prefix with
# "<child name>." at each level.
#
# Each module is loaded through its +load_from_state_dict+ method, which
# receives shared arrays for missing keys, unexpected keys and error messages.
# Afterwards, unexpected and missing keys (in that order) are appended to the
# error messages, and only the first collected message is raised.
#
# @param state_dict the state to load
# @param strict must be truthy; a falsy value is rejected
# @return [nil]
# @raise [RuntimeError] when +strict+ is falsy
# @raise [Error] with the first collected error message, if any
def load_state_dict(state_dict, strict: true)
  raise "strict: false not implemented yet" unless strict

  missing_keys = []
  unexpected_keys = []
  error_msgs = []

  loader = lambda do |mod, prefix = ""|
    mod.send(:load_from_state_dict, state_dict, prefix, {}, true, missing_keys, unexpected_keys, error_msgs)
    mod.named_children.each do |child_name, child|
      next if child.nil?

      loader.call(child, prefix + child_name + ".")
    end
  end
  loader.call(self)

  # strict is always truthy past the guard above, so both checks always run.
  error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}" if unexpected_keys.any?
  error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}" if missing_keys.any?

  raise Error, error_msgs[0] if error_msgs.any?

  nil
end
 
    
      
  
  
    #modules  ⇒ Object 
  
  
  
  
    | 
213
214
215 | # File 'lib/torch/nn/module.rb', line 213
# Lists this module and all modules nested below it, without their names.
#
# @return [Array] the values of #named_modules
def modules
  named_modules.values
end
 
    
      
  
  
    #named_buffers  ⇒ Object 
  
  
  
  
    | 
193
194
195 | # File 'lib/torch/nn/module.rb', line 193
# Returns the @buffers registry itself (not a copy), or a new empty hash when
# @buffers is nil or false.
#
# @return [Hash]
def named_buffers
  @buffers || {}
end
 
    
      
  
  
    #named_children  ⇒ Object 
  
  
  
  
    | 
201
202
203
204
205
206
207
208
209
210
211 | # File 'lib/torch/nn/module.rb', line 201
# Collects the direct children of this module keyed by name.
#
# Children come from two places: instance variables whose value is a Module
# (keyed by the variable name without the leading "@"), followed by the
# entries of @modules. An @modules entry replaces an instance-variable entry
# with the same key.
#
# @return [Hash] a new hash on every call
def named_children
  found = instance_variables.each_with_object({}) do |ivar, acc|
    value = instance_variable_get(ivar)
    acc[ivar[1..-1]] = value if value.is_a?(Module)
  end

  @modules.each { |child_name, child| found[child_name] = child }
  found
end
 
    
      
  
  
    #named_modules(memo: nil, prefix: "")  ⇒ Object 
  
  
  
  
  
    | 
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233 | # File 'lib/torch/nn/module.rb', line 218
# Collects this module and every module below it, keyed by dotted path.
#
# This module is stored under +prefix+; each child that is a Module is
# visited recursively with the prefix "<prefix>.<name>" (or just "<name>"
# when +prefix+ is empty). +memo+ records the modules already visited, so a
# module that was seen before contributes nothing.
#
# @param memo [Set, nil] modules already visited; a new Set when nil
# @param prefix [String] path under which this module is stored
# @return [Hash] path => module; empty when this module is already in +memo+
def named_modules(memo: nil, prefix: "")
  memo ||= Set.new
  return {} if memo.include?(self)

  memo << self
  found = {}
  found[prefix] = self

  separator = prefix.empty? ? "" : "."
  named_children.each do |child_name, child|
    next unless child.is_a?(Module)

    nested = child.named_modules(memo: memo, prefix: prefix + separator + child_name)
    nested.each { |path, mod| found[path] = mod }
  end

  found
end
 
    
      
  
  
    #named_parameters(prefix: "", recurse: true)  ⇒ Object 
  
  
  
  
    | 
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187 | # File 'lib/torch/nn/module.rb', line 172
# Collects parameters keyed by their prefixed name.
#
# Entries are added in this order: parameters of each child (only when
# +recurse+ is truthy, with "<prefix><child name>." as the child's prefix),
# then instance variables holding a Parameter (named after the variable
# without the leading "@"), then the non-nil entries of @parameters. Later
# entries replace earlier ones that share a key.
#
# @param prefix [String] prepended to every key
# @param recurse [Boolean] whether child modules are included
# @return [Hash] a new hash on every call
def named_parameters(prefix: "", recurse: true)
  collected = {}

  if recurse
    named_children.each do |child_name, child|
      nested = child.named_parameters(prefix: "#{prefix}#{child_name}.", recurse: recurse)
      collected.merge!(nested)
    end
  end

  instance_variables.each do |ivar|
    value = instance_variable_get(ivar)
    next unless value.is_a?(Parameter)

    collected[[prefix, ivar[1..-1]].join] = value
  end

  @parameters.each do |reg_name, value|
    next unless value

    collected[[prefix, reg_name].join] = value
  end

  collected
end
 
    
      
  
  
    #parameters  ⇒ Object 
  
  
  
  
    | 
168
169
170 | # File 'lib/torch/nn/module.rb', line 168
# Lists the parameters of this module without their names.
#
# @return [Array] the values of #named_parameters called with its defaults
def parameters
  named_parameters.values
end
 
    
      
  
  
    #register_buffer(name, tensor)  ⇒ Object 
  
  
  
  
    | 
19
20
21
22
23 | # File 'lib/torch/nn/module.rb', line 19
# Stores +tensor+ in the @buffers registry under +name+ and mirrors it into
# the instance variable "@<name>".
#
# @param name key for the buffer; also used to build the instance variable name
# @param tensor the value to store
# @return the stored +tensor+
def register_buffer(name, tensor)
  @buffers[name] = tensor
  instance_variable_set("@#{name}", tensor)
end
 
    
      
  
  
    #register_parameter(name, param)  ⇒ Object 
  
  
  
  
    | 
25
26
27
28 | # File 'lib/torch/nn/module.rb', line 25
# Stores +param+ in the @parameters registry under +name+. Unlike
# #register_buffer, no instance variable is set.
#
# @param name key for the parameter
# @param param the value to store
# @return the stored +param+
def register_parameter(name, param)
  @parameters[name] = param
end
 
    
      
  
  
    #requires_grad!(requires_grad: true)  ⇒ Object 
  
  
  
  
    | 
247
248
249
250
251
252 | # File 'lib/torch/nn/module.rb', line 247
# Calls +requires_grad!+ with the given flag on every entry of #parameters.
#
# @param requires_grad flag forwarded to each parameter
# @return [self]
def requires_grad!(requires_grad: true)
  parameters.each { |param| param.requires_grad!(requires_grad) }
  self
end
 
    
      
  
  
    #respond_to?(method, include_private = false)  ⇒ Boolean 
  
  
  
  
    | 
307
308
309
310 | # File 'lib/torch/nn/module.rb', line 307
# Reports true for any name found among the named parameters, named buffers,
# or named modules (checked in that order); otherwise defers to +super+.
#
# @param method name to look up (converted with +to_s+)
# @param include_private forwarded to +super+
# @return [Boolean]
def respond_to?(method, include_private = false)
  key = method.to_s
  return true if named_parameters.key?(key)
  return true if named_buffers.key?(key)
  return true if named_modules.key?(key)

  super
end
 
    
      
  
  
    #share_memory  ⇒ Object 
  
  
  
  
    | 
263
264
265 | # File 'lib/torch/nn/module.rb', line 263
# Runs +share_memory!+ on every tensor visited by #_apply.
#
# @return the result of #_apply
def share_memory
  _apply(lambda { |tensor| tensor.share_memory! })
end
 
    
      
  
  
    #state_dict(destination: nil, prefix: "")  ⇒ Object 
  
  
  
  
    | 
118
119
120
121
122
123
124
125
126
127 | # File 'lib/torch/nn/module.rb', line 118
# Fills +destination+ with this module's state and that of its children.
#
# This module's own entries are written by +save_to_state_dict+; every
# non-nil child is then asked for its +state_dict+ with the same destination
# and the prefix "<prefix><child name>.".
#
# @param destination [Hash, nil] hash to fill; a new hash when nil
# @param prefix [String] prepended to this module's keys
# @return [Hash] the filled +destination+
def state_dict(destination: nil, prefix: "")
  destination ||= {}
  save_to_state_dict(destination, prefix: prefix)

  named_children.each do |child_name, child|
    child.state_dict(destination: destination, prefix: prefix + child_name + ".") if child
  end

  destination
end
 
    
      
  
  
    #to(device)  ⇒ Object 
  
  
  
  
  
    | 
106
107
108
109
110
111
112 | # File 'lib/torch/nn/module.rb', line 106
# Runs +to(device)+ on every tensor visited by #_apply.
#
# @param device forwarded unchanged to each tensor's +to+
# @return the result of #_apply
def to(device)
  _apply(->(tensor) { tensor.to(device) })
end
 
    
      
  
  
    #train(mode = true)  ⇒ Object 
  
  
  
  
    | 
235
236
237
238
239
240
241 | # File 'lib/torch/nn/module.rb', line 235
# Sets @training to +mode+ on this module, then calls +train(mode)+ on each
# child.
#
# @param mode value stored in @training (defaults to true)
# @return [self]
def train(mode = true)
  @training = mode
  children.each { |child| child.train(mode) }
  self
end
 
    
      
  
  
    #type(dst_type)  ⇒ Object 
  
  
  
  
    | 
89
90
91 | # File 'lib/torch/nn/module.rb', line 89
# Runs +type(dst_type)+ on every tensor visited by #_apply.
#
# @param dst_type forwarded unchanged to each tensor's +type+
# @return the result of #_apply
def type(dst_type)
  _apply(lambda { |tensor| tensor.type(dst_type) })
end
 
    
      
  
  
    #zero_grad  ⇒ Object 
  
  
  
  
    | 
254
255
256
257
258
259
260
261 | # File 'lib/torch/nn/module.rb', line 254
# For every entry of #parameters that has a gradient, calls +detach!+ and
# then +zero!+ on that gradient. Parameters without a gradient are skipped.
#
# @return the array that was iterated (the result of #parameters)
def zero_grad
  parameters.each do |param|
    next unless param.grad

    param.grad.detach!
    param.grad.zero!
  end
end