Class: Torch::NN::Module

Inherits:
Object
Includes:
Utils
Defined in:
lib/torch/nn/module.rb

Instance Method Summary

Methods included from Utils

#_ntuple, #_pair, #_quadrupal, #_single, #_triple

Constructor Details

#initialize ⇒ Module

Returns a new instance of Module.



# File 'lib/torch/nn/module.rb', line 6

def initialize
  @training = true
  @parameters = {}
  @buffers = {}
  @modules = {}
end

Dynamic Method Handling

This class handles dynamic methods through the method_missing method.

#method_missing(method, *args, &block) ⇒ Object



# File 'lib/torch/nn/module.rb', line 239

def method_missing(method, *args, &block)
  name = method.to_s
  if named_parameters.key?(name)
    named_parameters[name]
  elsif named_buffers.key?(name)
    named_buffers[name]
  elsif named_modules.key?(name)
    named_modules[name]
  else
    super
  end
end
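
A registered parameter, buffer, or child module therefore becomes readable by name. A minimal sketch (the Scaler class and its attribute names are illustrative, not part of the library):

require "torch"

class Scaler < Torch::NN::Module
  def initialize
    super()
    @linear = Torch::NN::Linear.new(4, 2)           # picked up by named_children
    register_buffer("running_mean", Torch.zeros(4)) # picked up by named_buffers
  end

  def forward(x)
    @linear.call(x - @running_mean)
  end
end

m = Scaler.new
m.running_mean # resolved by method_missing via named_buffers
m.linear       # resolved by method_missing via named_modules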

Instance Method Details

#_apply(fn) ⇒ Object



# File 'lib/torch/nn/module.rb', line 33

def _apply(fn)
  children.each do |mod|
    mod._apply(fn)
  end

  instance_variables.each do |key|
    param = instance_variable_get(key)
    if param.is_a?(Parameter)
      param_applied = nil
      Torch.no_grad do
        param_applied = fn.call(param)
      end
      # TODO should_use_set_data
      instance_variable_set(key, Parameter.new(param_applied, requires_grad: param.requires_grad))

      if param.grad
        grad_applied = nil
        Torch.no_grad do
          grad_applied = fn.call(param.grad)
        end
        # TODO should_use_set_data
        instance_variable_get(key).grad = grad_applied.requires_grad!(param.grad.requires_grad)
      end
    end
  end
  # TODO apply to more objects
  self
end

#add_module(name, mod) ⇒ Object



# File 'lib/torch/nn/module.rb', line 28

def add_module(name, mod)
  # TODO add checks
  @modules[name] = mod
end

#apply(fn) ⇒ Object



# File 'lib/torch/nn/module.rb', line 62

def apply(fn)
  children.each do |mod|
    mod.apply(fn)
  end
  fn.call(self)
  self
end
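
apply is handy for visiting every submodule, for example to freeze a particular layer type. A sketch (model stands for any Module instance):

model.apply(->(mod) {
  # freeze only the Linear layers; their parameters stop requiring gradients
  mod.requires_grad!(requires_grad: false) if mod.is_a?(Torch::NN::Linear)
})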

#buffersObject



# File 'lib/torch/nn/module.rb', line 161

def buffers
  named_buffers.values
end

#call(*input, **kwargs) ⇒ Object



# File 'lib/torch/nn/module.rb', line 104

def call(*input, **kwargs)
  forward(*input, **kwargs)
end
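
Since call just delegates to forward, a module can be invoked like a function. Sketch:

layer = Torch::NN::Linear.new(10, 1)
layer.call(Torch.randn(3, 10)) # equivalent to layer.forward(Torch.randn(3, 10))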

#childrenObject



# File 'lib/torch/nn/module.rb', line 169

def children
  named_children.values
end

#cpuObject



# File 'lib/torch/nn/module.rb', line 75

def cpu
  _apply ->(t) { t.cpu }
end

#cudaObject

TODO add device



# File 'lib/torch/nn/module.rb', line 71

def cuda
  _apply ->(t) { t.cuda }
end

#doubleObject



# File 'lib/torch/nn/module.rb', line 87

def double
  _apply ->(t) { t.floating_point? ? t.double : t }
end

#evalObject



# File 'lib/torch/nn/module.rb', line 201

def eval
  train(false)
end

#floatObject



# File 'lib/torch/nn/module.rb', line 83

def float
  _apply ->(t) { t.floating_point? ? t.float : t }
end

#forwardObject

Raises:

  • (NotImplementedError)


# File 'lib/torch/nn/module.rb', line 13

def forward
  raise NotImplementedError
end
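
Subclasses override forward and must call super() in their constructor so @parameters, @buffers, and @modules are set up. A minimal sketch (Torch::NN::F.relu follows the convention used in the gem's README):

class TinyNet < Torch::NN::Module
  def initialize
    super()
    @fc1 = Torch::NN::Linear.new(8, 16)
    @fc2 = Torch::NN::Linear.new(16, 1)
  end

  def forward(x)
    @fc2.call(Torch::NN::F.relu(@fc1.call(x)))
  end
end

TinyNet.new.call(Torch.randn(2, 8)) # forward is reached through #call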

#halfObject



# File 'lib/torch/nn/module.rb', line 91

def half
  _apply ->(t) { t.floating_point? ? t.half : t }
end

#inspectObject



# File 'lib/torch/nn/module.rb', line 225

def inspect
  name = self.class.name.split("::").last
  if children.empty?
    "#{name}(#{extra_inspect})"
  else
    str = String.new
    str << "#{name}(\n"
    named_children.each do |name, mod|
      str << "  (#{name}): #{mod.inspect}\n"
    end
    str << ")"
  end
end

#load_state_dict(state_dict) ⇒ Object

TODO add strict option. TODO match PyTorch behavior.



# File 'lib/torch/nn/module.rb', line 118

def load_state_dict(state_dict)
  state_dict.each do |k, input_param|
    k1, k2 = k.split(".", 2)
    mod = named_modules[k1]
    if mod.is_a?(Module)
      param = mod.named_parameters[k2]
      if param.is_a?(Parameter)
        Torch.no_grad do
          param.copy!(input_param)
        end
      else
        raise Error, "Unknown parameter: #{k1}"
      end
    else
      raise Error, "Unknown module: #{k1}"
    end
  end

  # TODO return missing keys and unexpected keys
  nil
end
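
A typical round trip with #state_dict, using Torch.save and Torch.load as in the gem's README (the file name and the TinyNet class from the #forward sketch are illustrative):

net = TinyNet.new
Torch.save(net.state_dict, "net.pth")

restored = TinyNet.new
restored.load_state_dict(Torch.load("net.pth"))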

#modulesObject



# File 'lib/torch/nn/module.rb', line 185

def modules
  named_modules.values
end

#named_buffersObject



# File 'lib/torch/nn/module.rb', line 165

def named_buffers
  @buffers || {}
end

#named_childrenObject



# File 'lib/torch/nn/module.rb', line 173

def named_children
  modules = {}
  instance_variables.each do |name|
    mod = instance_variable_get(name)
    modules[name[1..-1]] = mod if mod.is_a?(Module)
  end
  @modules.each do |name, mod|
    modules[name] = mod
  end
  modules
end

#named_modulesObject



# File 'lib/torch/nn/module.rb', line 189

def named_modules
  {"" => self}.merge(named_children)
end

#named_parameters(prefix: "", recurse: true) ⇒ Object



# File 'lib/torch/nn/module.rb', line 144

def named_parameters(prefix: "", recurse: true)
  params = {}
  if recurse
    named_children.each do |name, mod|
      params.merge!(mod.named_parameters(prefix: "#{name}.", recurse: recurse))
    end
  end
  instance_variables.each do |name|
    param = instance_variable_get(name)
    params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
  end
  @parameters.each do |name, param|
    params[[prefix, name].join] = param if param
  end
  params
end
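
Keys combine the child name and the parameter name. With the TinyNet sketch from #forward (the printed shapes are illustrative):

net = TinyNet.new
net.named_parameters.each do |name, param|
  puts "#{name}: #{param.shape.inspect}"
end
# fc1.weight: [16, 8]
# fc1.bias: [16]
# fc2.weight: [1, 16]
# fc2.bias: [1]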

#parametersObject



# File 'lib/torch/nn/module.rb', line 140

def parameters
  named_parameters.values
end

#register_buffer(name, tensor) ⇒ Object



# File 'lib/torch/nn/module.rb', line 17

def register_buffer(name, tensor)
  # TODO add checks
  @buffers[name] = tensor
  instance_variable_set("@#{name}", tensor)
end

#register_parameter(name, param) ⇒ Object



# File 'lib/torch/nn/module.rb', line 23

def register_parameter(name, param)
  # TODO add checks
  @parameters[name] = param
end
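
Unlike register_buffer, this only records the parameter in @parameters; it does not set an instance variable, so the value is reached through the dynamic lookup in method_missing. A sketch (the Bias class is illustrative):

class Bias < Torch::NN::Module
  def initialize(n)
    super()
    register_parameter("bias", Torch::NN::Parameter.new(Torch.zeros(n)))
  end

  def forward(x)
    x + bias # "bias" resolves through method_missing / named_parameters
  end
end

Bias.new(3).call(Torch.ones(2, 3))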

#requires_grad!(requires_grad: true) ⇒ Object



# File 'lib/torch/nn/module.rb', line 205

def requires_grad!(requires_grad: true)
  parameters.each do |p|
    p.requires_grad!(requires_grad)
  end
  self
end
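
A common use is freezing an entire model, for example before fine-tuning a new head. Sketch:

net.requires_grad!(requires_grad: false)   # freeze every parameter
net.parameters.map(&:requires_grad).uniq   # => [false]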

#respond_to?(method, include_private = false) ⇒ Boolean

Returns:

  • (Boolean)


# File 'lib/torch/nn/module.rb', line 252

def respond_to?(method, include_private = false)
  name = method.to_s
  named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
end

#share_memoryObject



# File 'lib/torch/nn/module.rb', line 221

def share_memory
  _apply ->(t) { t.share_memory! }
end

#state_dict(destination: nil) ⇒ Object



# File 'lib/torch/nn/module.rb', line 108

def state_dict(destination: nil)
  destination ||= {}
  named_parameters.each do |k, v|
    destination[k] = v
  end
  destination
end
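
The result is a plain hash of dotted parameter names to tensors, mirroring named_parameters. With the TinyNet sketch from #forward:

TinyNet.new.state_dict.keys
# e.g. ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"]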

#to(device) ⇒ Object

modifies in-place



# File 'lib/torch/nn/module.rb', line 96

def to(device)
  convert = lambda do |t|
    t.to(device)
  end

  _apply(convert)
end
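
Sketch of moving a module (and its inputs) to the GPU when one is available; string device names follow the gem's README:

device = Torch::CUDA.available? ? "cuda" : "cpu"
net.to(device)                        # parameters and buffers are converted in-place
input = Torch.randn(2, 8).to(device)
net.call(input)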

#train(mode = true) ⇒ Object



# File 'lib/torch/nn/module.rb', line 193

def train(mode = true)
  @training = mode
  children.each do |mod|
    mod.train(mode)
  end
  self
end
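
train and eval toggle the @training flag on the whole module tree, which layers such as Dropout and BatchNorm consult. Sketch:

net.train # training-mode behavior (e.g. dropout active)
net.eval  # inference-mode behavior, shorthand for train(false)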

#type(dst_type) ⇒ Object



# File 'lib/torch/nn/module.rb', line 79

def type(dst_type)
  _apply ->(t) { t.type(dst_type) }
end

#zero_gradObject



# File 'lib/torch/nn/module.rb', line 212

def zero_grad
  parameters.each do |param|
    if param.grad
      param.grad.detach!
      param.grad.zero!
    end
  end
end
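
Typically called once per optimization step to clear the gradients accumulated by backward. A sketch of a single step with the TinyNet example from #forward (the squared-error loss is illustrative):

net = TinyNet.new
x, y = Torch.randn(4, 8), Torch.randn(4, 1)
optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01)

net.zero_grad                        # clear gradients from the previous step
loss = (net.call(x) - y).pow(2).mean # squared error
loss.backward
optimizer.step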