Class: Torch::NN::Module

Inherits:
Object
Includes:
Utils
Defined in:
lib/torch/nn/module.rb

Instance Method Summary

Methods included from Utils

#_ntuple, #_pair, #_quadrupal, #_single, #_triple

Constructor Details

#initialize ⇒ Module

Returns a new instance of Module.



6
7
8
9
10
11
# File 'lib/torch/nn/module.rb', line 6

# Sets up an empty module in training mode with no registered parameters,
# buffers or submodules yet.
def initialize
  @training = true
  # three independent, initially empty registries (assigned left to right)
  @parameters, @buffers, @modules = {}, {}, {}
end

Dynamic Method Handling

This class handles dynamic methods through the method_missing method.

#method_missing(method, *args, &block) ⇒ Object



281
282
283
284
285
286
287
288
289
290
291
292
# File 'lib/torch/nn/module.rb', line 281

# Exposes registered parameters, buffers and submodules as reader methods:
# `mod.weight` returns the entry stored under the key "weight".
#
# Lookup order is parameters, then buffers, then modules. Each table is
# computed at most once per call. named_parameters and named_modules walk
# the whole submodule tree, so calling them twice (once for key?, once for
# []) doubled the cost of every dynamic access.
#
# @param method [Symbol] name of the method that was not found
# @return [Object] the matching parameter, buffer or module
# Falls back to super (normally raising NoMethodError) when nothing matches.
def method_missing(method, *args, &block)
  name = method.to_s

  params = named_parameters
  return params[name] if params.key?(name)

  bufs = named_buffers
  return bufs[name] if bufs.key?(name)

  mods = named_modules
  return mods[name] if mods.key?(name)

  super
end

Instance Method Details

#_apply(fn) ⇒ Object



33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# File 'lib/torch/nn/module.rb', line 33

# Recursively applies fn to every submodule, then to this module's own
# parameters and buffers, replacing each with the converted tensor.
# #cpu, #cuda, #to, #float, #double, #half, #type and #share_memory are
# all built on this.
#
# Parameters live in two places: Parameter-valued instance variables, and
# the @parameters hash filled by #register_parameter. Both are converted.
# Previously only instance variables were, so registered parameters were
# left unconverted. A Parameter object reachable from several places is
# converted once, and that single result is stored everywhere, so shared
# (tied) parameters stay shared.
#
# @param fn [#call] receives a tensor and returns the converted tensor
# @return [self]
def _apply(fn)
  children.each do |mod|
    mod._apply(fn)
  end

  # Converts one Parameter (and its gradient, if any) inside Torch.no_grad
  # and returns a new Parameter wrapping the result.
  convert_param = lambda do |param|
    param_applied = nil
    Torch.no_grad do
      param_applied = fn.call(param)
    end
    # TODO should_use_set_data
    new_param = Parameter.new(param_applied, requires_grad: param.requires_grad)

    if param.grad
      grad_applied = nil
      Torch.no_grad do
        grad_applied = fn.call(param.grad)
      end
      # TODO should_use_set_data
      new_param.grad = grad_applied.requires_grad!(param.grad.requires_grad)
    end

    new_param
  end

  # keyed by object identity so the lookup never depends on Tensor#hash / #eql?
  converted = {}.compare_by_identity

  instance_variables.each do |key|
    param = instance_variable_get(key)
    next unless param.is_a?(Parameter)
    converted[param] ||= convert_param.call(param)
    instance_variable_set(key, converted[param])
  end

  # parameters added via register_parameter are stored only in @parameters
  @parameters.each_key do |name|
    param = @parameters[name]
    next unless param.is_a?(Parameter)
    converted[param] ||= convert_param.call(param)
    @parameters[name] = converted[param]
  end

  @buffers.each_key do |k|
    buf = @buffers[k]
    unless buf.nil?
      @buffers[k] = fn.call(buf)
      # keep the instance-variable alias created by register_buffer in sync
      instance_variable_set("@#{k}", @buffers[k])
    end
  end

  self
end

#add_module(name, mod) ⇒ Object



28
29
30
31
# File 'lib/torch/nn/module.rb', line 28

# Registers mod as a child under name, so it shows up in #named_children.
# Returns mod. Neither argument is validated yet.
def add_module(name, mod)
  # TODO add checks
  @modules.store(name, mod)
end

#apply(fn) ⇒ Object



70
71
72
73
74
75
76
# File 'lib/torch/nn/module.rb', line 70

# Calls fn on every submodule first (recursively), then on self.
# Returns self.
def apply(fn)
  children.each { |child| child.apply(fn) }
  fn.call(self)
  self
end

#buffers ⇒ Object



187
188
189
# File 'lib/torch/nn/module.rb', line 187

# The buffer tensors from #named_buffers, without their names.
def buffers
  named_buffers.each_value.to_a
end

#call(*input, **kwargs) ⇒ Object



112
113
114
# File 'lib/torch/nn/module.rb', line 112

# Runs the module. All positional and keyword arguments go straight to
# #forward, and its result is returned unchanged.
def call(*input, **kwargs)
  forward(*input, **kwargs)
end

#children ⇒ Object



195
196
197
# File 'lib/torch/nn/module.rb', line 195

# Direct child modules (the values of #named_children), without names.
def children
  named_children.each_value.to_a
end

#cpu ⇒ Object



83
84
85
# File 'lib/torch/nn/module.rb', line 83

# Replaces every parameter and buffer (recursively, via #_apply) with the
# result of calling Tensor#cpu on it.
def cpu
  _apply(:cpu.to_proc)
end

#cuda ⇒ Object

TODO add device



79
80
81
# File 'lib/torch/nn/module.rb', line 79

# Replaces every parameter and buffer (recursively, via #_apply) with the
# result of calling Tensor#cuda on it.
# TODO add device
def cuda
  _apply(:cuda.to_proc)
end

#double ⇒ Object



95
96
97
# File 'lib/torch/nn/module.rb', line 95

# Casts floating-point parameters and buffers to double via #_apply.
# Non-floating-point tensors are passed through unchanged.
def double
  to_double = lambda do |tensor|
    if tensor.floating_point?
      tensor.double
    else
      tensor
    end
  end
  _apply(to_double)
end

#eval ⇒ Object



241
242
243
# File 'lib/torch/nn/module.rb', line 241

# Puts the module into evaluation mode; shorthand for train(false).
def eval
  train(false)
end

#float ⇒ Object



91
92
93
# File 'lib/torch/nn/module.rb', line 91

# Casts floating-point parameters and buffers to float via #_apply.
# Non-floating-point tensors are passed through unchanged.
def float
  to_float = lambda do |tensor|
    if tensor.floating_point?
      tensor.float
    else
      tensor
    end
  end
  _apply(to_float)
end

#forward ⇒ Object

Raises:

  • (NotImplementedError)


13
14
15
# File 'lib/torch/nn/module.rb', line 13

# Computes the module's output. Subclasses must override this; the base
# implementation always raises.
#
# It accepts (and ignores) any arguments. #call forwards its arguments here,
# so a module without a forward override now reports NotImplementedError
# rather than an ArgumentError about the argument count.
#
# @raise [NotImplementedError] always
def forward(*_args, **_kwargs)
  raise NotImplementedError, "#{self.class.name} must implement #forward"
end

#half ⇒ Object



99
100
101
# File 'lib/torch/nn/module.rb', line 99

# Casts floating-point parameters and buffers to half precision via #_apply.
# Non-floating-point tensors are passed through unchanged.
def half
  to_half = lambda do |tensor|
    if tensor.floating_point?
      tensor.half
    else
      tensor
    end
  end
  _apply(to_half)
end

#inspect ⇒ Object



265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
# File 'lib/torch/nn/module.rb', line 265

# Human-readable summary. A leaf prints as "Name(<extra_inspect>)". A
# module with children prints one "(child_name): <child inspect>" line per
# child, with multi-line child output indented by two spaces per level.
def inspect
  class_label = self.class.name.split("::").last
  kids = named_children
  return "#{class_label}(#{extra_inspect})" if kids.empty?

  out = String.new
  out << "#{class_label}(\n"
  kids.each do |child_name, child|
    # indent every continuation line of the child's own output
    nested = child.inspect.lines.join("  ")
    out << "  (#{child_name}): #{nested}\n"
  end
  out << ")"
end

#load_state_dict(state_dict, strict: true) ⇒ Object



127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# File 'lib/torch/nn/module.rb', line 127

# Loads tensors from state_dict into this module and, recursively, into its
# non-nil submodules.
#
# Each module's private load_from_state_dict is called with the key prefix
# for its place in the tree ("" for self, "child." for a child,
# "child.grandchild." and so on). It also receives the shared
# missing_keys / unexpected_keys / error_msgs arrays, which it presumably
# appends to (NOTE(review): confirm against load_from_state_dict).
#
# @param state_dict [Hash] presumably key => tensor, as produced by #state_dict
# @param strict [Boolean] only true is supported so far
# @return [nil]
# @raise [RuntimeError] when strict is false (not implemented yet)
# @raise [Error] with only the FIRST collected message when any error was
#   recorded or any key is unexpected or missing
def load_state_dict(state_dict, strict: true)
  # TODO support strict: false
  raise "strict: false not implemented yet" unless strict

  missing_keys = []
  unexpected_keys = []
  error_msgs = []

  # TODO handle metadata

  _load = lambda do |mod, prefix = ""|
    # TODO handle metadata
    # (this local was lost, leaving "= {}" and an empty argument slot: a syntax error)
    local_metadata = {}
    mod.send(:load_from_state_dict, state_dict, prefix, local_metadata, true, missing_keys, unexpected_keys, error_msgs)
    mod.named_children.each do |name, child|
      _load.call(child, prefix + name + ".") unless child.nil?
    end
  end

  _load.call(self)

  # always true for now (strict: false raises above); kept for when it is supported
  if strict
    if unexpected_keys.any?
      error_msgs << "Unexpected key(s) in state_dict: #{unexpected_keys.join(", ")}"
    end

    if missing_keys.any?
      error_msgs << "Missing key(s) in state_dict: #{missing_keys.join(", ")}"
    end
  end

  if error_msgs.any?
    # just show first error
    raise Error, error_msgs[0]
  end

  nil
end

#modules ⇒ Object



211
212
213
# File 'lib/torch/nn/module.rb', line 211

# Every module in the tree (self included), taken from #named_modules
# without the dotted paths.
def modules
  named_modules.each_value.to_a
end

#named_buffers ⇒ Object



191
192
193
# File 'lib/torch/nn/module.rb', line 191

# This module's own buffers, keyed by the name given to #register_buffer.
# Unlike #named_parameters, it does not descend into submodules. If @buffers
# was never set, an empty hash is returned instead.
def named_buffers
  return {} unless @buffers

  @buffers
end

#named_children ⇒ Object



199
200
201
202
203
204
205
206
207
208
209
# File 'lib/torch/nn/module.rb', line 199

# Direct submodules keyed by name. Instance variables holding a Module are
# collected first (@encoder => "encoder"). Entries added through #add_module
# are merged on top, and they win on name clashes.
def named_children
  found = {}
  instance_variables.each do |ivar|
    value = instance_variable_get(ivar)
    next unless value.is_a?(Module)
    found[ivar[1..-1]] = value
  end
  found.merge!(@modules)
  found
end

#named_modules(memo: nil, prefix: "") ⇒ Object

TODO return enumerator?



216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
# File 'lib/torch/nn/module.rb', line 216

# Flattens the module tree into a hash of dotted path => module, with self
# under the key prefix ("" by default). memo records modules already visited,
# so a module reachable through several paths appears only once, under the
# first path found.
def named_modules(memo: nil, prefix: "")
  memo ||= Set.new
  found = {}
  # already visited: contribute nothing
  return found if memo.include?(self)

  memo << self
  found[prefix] = self
  named_children.each do |child_name, child|
    next unless child.is_a?(Module)
    separator = prefix.empty? ? "" : "."
    found.merge!(child.named_modules(memo: memo, prefix: prefix + separator + child_name))
  end
  found
end

#named_parameters(prefix: "", recurse: true) ⇒ Object



170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# File 'lib/torch/nn/module.rb', line 170

# All parameters keyed by dotted name. When recurse is true, the entries of
# submodules (namespaced "<prefix><child>.") come first. Then come this
# module's Parameter-valued instance variables, then non-nil entries of
# @parameters (from #register_parameter); later keys overwrite earlier ones.
def named_parameters(prefix: "", recurse: true)
  result = {}

  if recurse
    named_children.each do |child_name, child|
      nested = child.named_parameters(prefix: "#{prefix}#{child_name}.", recurse: recurse)
      result.merge!(nested)
    end
  end

  instance_variables.each do |ivar|
    candidate = instance_variable_get(ivar)
    next unless candidate.is_a?(Parameter)
    result[[prefix, ivar[1..-1]].join] = candidate
  end

  @parameters.each do |key, candidate|
    next unless candidate
    result[[prefix, key].join] = candidate
  end

  result
end

#parameters ⇒ Object



166
167
168
# File 'lib/torch/nn/module.rb', line 166

# All parameters (recursively), taken from #named_parameters without names.
def parameters
  named_parameters.each_value.to_a
end

#register_buffer(name, tensor) ⇒ Object



17
18
19
20
21
# File 'lib/torch/nn/module.rb', line 17

# Stores tensor in @buffers under name and also exposes it as the
# instance variable "@<name>". Returns tensor. Nothing is validated yet.
def register_buffer(name, tensor)
  # TODO add checks
  instance_variable_set("@#{name}", @buffers[name] = tensor)
end

#register_parameter(name, param) ⇒ Object



23
24
25
26
# File 'lib/torch/nn/module.rb', line 23

# Stores param in @parameters under name (no instance variable is created).
# Returns param. Nothing is validated yet.
def register_parameter(name, param)
  # TODO add checks
  @parameters.store(name, param)
end

#requires_grad!(requires_grad: true) ⇒ Object



245
246
247
248
249
250
# File 'lib/torch/nn/module.rb', line 245

# Calls requires_grad!(requires_grad) on every parameter from #parameters.
# Returns self so calls can be chained.
def requires_grad!(requires_grad: true)
  parameters.each { |param| param.requires_grad!(requires_grad) }
  self
end

#respond_to?(method, include_private = false) ⇒ Boolean

Returns:

  • (Boolean)


294
295
296
297
# File 'lib/torch/nn/module.rb', line 294

# Reports the dynamic readers served by #method_missing: a name responds
# when it is a key of named_parameters, named_buffers or named_modules.
#
# Defined as respond_to_missing? rather than overriding respond_to?. Kernel's
# respond_to? consults it automatically, and so do #method and friends: with
# only respond_to? overridden, `mod.method(:weight)` raised NameError.
#
# @param method [Symbol, String] name being queried
# @param include_private [Boolean] forwarded to super
# @return [Boolean]
def respond_to_missing?(method, include_private = false)
  name = method.to_s
  named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
end

#share_memory ⇒ Object



261
262
263
# File 'lib/torch/nn/module.rb', line 261

# Calls Tensor#share_memory! on every parameter and buffer (recursively, via
# #_apply) and stores the results back.
def share_memory
  _apply(:share_memory!.to_proc)
end

#state_dict(destination: nil, prefix: "") ⇒ Object



116
117
118
119
120
121
122
123
124
125
# File 'lib/torch/nn/module.rb', line 116

# Collects this module's entries (via the private save_to_state_dict) and
# then those of every non-nil child, with child keys prefixed
# "<prefix><child>.". Writes into destination when given, otherwise into a
# new hash, and returns that hash.
def state_dict(destination: nil, prefix: "")
  out = destination || {}
  save_to_state_dict(out, prefix: prefix)
  named_children.each do |child_name, child|
    child.state_dict(destination: out, prefix: prefix + child_name + ".") if child
  end
  out
end

#to(device) ⇒ Object

modifies in-place



104
105
106
107
108
109
110
# File 'lib/torch/nn/module.rb', line 104

# Replaces every parameter and buffer (recursively, via #_apply) with
# tensor.to(device). The module itself is updated in place; returns what
# #_apply returns.
def to(device)
  _apply(->(tensor) { tensor.to(device) })
end

#train(mode = true) ⇒ Object



233
234
235
236
237
238
239
# File 'lib/torch/nn/module.rb', line 233

# Sets the @training flag to mode on this module and, recursively, on every
# child. Returns self.
def train(mode = true)
  @training = mode
  children.each { |child| child.train(mode) }
  self
end

#type(dst_type) ⇒ Object



87
88
89
# File 'lib/torch/nn/module.rb', line 87

# Replaces every parameter and buffer (recursively, via #_apply) with
# tensor.type(dst_type).
def type(dst_type)
  cast = ->(tensor) { tensor.type(dst_type) }
  _apply(cast)
end

#zero_grad ⇒ Object



252
253
254
255
256
257
258
259
# File 'lib/torch/nn/module.rb', line 252

# For every parameter that has a gradient, detaches it (detach!) and then
# zeroes it (zero!). Parameters without a gradient are skipped. Returns the
# array from #parameters.
def zero_grad
  parameters.each do |param|
    next unless param.grad

    param.grad.detach!
    param.grad.zero!
  end
end