Class: Torch::NN::Module

Inherits:
Object
Includes:
Utils
Defined in:
lib/torch/nn/module.rb

Instance Method Summary

Methods included from Utils

#_ntuple, #_pair, #_quadrupal, #_single, #_triple

Constructor Details

#initialize ⇒ Module

Returns a new instance of Module.



6
7
8
9
10
11
# File 'lib/torch/nn/module.rb', line 6

# Sets up an empty module.
#
# A fresh module starts in training mode and owns no explicitly
# registered parameters, buffers or submodules. Subclasses are expected
# to call +super+ before registering their own state.
def initialize
  # Training flag toggled by #train / #eval.
  @training = true

  # Registries populated by #register_parameter, #register_buffer and
  # #add_module, each keyed by name.
  @parameters = {}
  @buffers = {}
  @modules = {}
end

Dynamic Method Handling

This class handles dynamic methods through the method_missing method.

#method_missing(method, *args, &block) ⇒ Object



189
190
191
192
193
194
195
196
197
198
199
200
# File 'lib/torch/nn/module.rb', line 189

# Exposes registered state as reader methods.
#
# An unknown method name is resolved, in order, against #named_parameters,
# #named_buffers and #named_modules. The first table containing the name
# wins. Names found nowhere are passed to +super+ (normally raising
# NoMethodError).
#
# Each lookup table is built at most once per call. #named_parameters
# recursively walks every submodule, so the previous key?-then-[] pattern
# doubled that walk on every hit.
#
# NOTE(review): +args+ and +block+ are ignored when a name resolves, so
# e.g. +mod.fc(x)+ returns the +fc+ module rather than calling it.
#
# @param method [Symbol] the missing method's name
# @return [Object] the matching parameter, buffer or module
def method_missing(method, *args, &block)
  name = method.to_s

  params = named_parameters
  return params[name] if params.key?(name)

  bufs = named_buffers
  return bufs[name] if bufs.key?(name)

  mods = named_modules
  return mods[name] if mods.key?(name)

  super
end

Instance Method Details

#_apply(fn) ⇒ Object



33
34
35
36
37
38
39
# File 'lib/torch/nn/module.rb', line 33

# Applies +fn+ to the tensors owned by this module tree; returns +self+.
#
# Recurses into every child first. It then replaces each non-nil
# registered buffer with +fn.call(buffer)+. #register_buffer stores a
# buffer both in @buffers and as an instance variable of the same name,
# so both copies are updated to stay in sync.
#
# Used by #to, #cpu, #cuda, #float, #double, #half, #type and
# #share_memory.
#
# @param fn [#call] receives a tensor and returns its replacement
# @return [self]
def _apply(fn)
  children.each do |mod|
    mod._apply(fn)
  end

  # @buffers may be nil when a subclass skipped super (see #named_buffers).
  bufs = @buffers || {}
  # Iterate over a snapshot of the keys while rewriting values.
  bufs.keys.each do |key|
    buf = bufs[key]
    next if buf.nil?

    bufs[key] = fn.call(buf)
    instance_variable_set("@#{key}", bufs[key])
  end

  # TODO apply fn to parameters as well
  self
end

#add_module(name, mod) ⇒ Object



28
29
30
31
# File 'lib/torch/nn/module.rb', line 28

# Registers +mod+ as a child module under +name+.
#
# @param name [String] key in the module registry
# @param mod [Module] the child module
# @return [Module] +mod+
def add_module(name, mod)
  # TODO add checks
  @modules.store(name, mod)
end

#apply(fn) ⇒ Object



41
42
43
44
45
46
47
# File 'lib/torch/nn/module.rb', line 41

# Calls +fn+ on every module in the tree, children before parents.
#
# @param fn [#call] receives each module in turn
# @return [self]
def apply(fn)
  children.each { |child| child.apply(fn) }
  tap { fn.call(self) }
end

#buffers ⇒ Object



111
112
113
# File 'lib/torch/nn/module.rb', line 111

# Returns the registered buffers as an Array, without their names.
def buffers
  named_buffers.map { |_name, tensor| tensor }
end

#call(*input) ⇒ Object



82
83
84
# File 'lib/torch/nn/module.rb', line 82

# Runs the module by delegating to #forward.
#
# Positional arguments, keyword arguments and an optional block are all
# forwarded unchanged. Before, only positional arguments were forwarded.
# On Ruby 3 that meant +call(x, mask: m)+ passed a positional Hash to a
# +forward+ expecting keywords, and blocks were silently dropped. A
# +forward+ that takes no keywords still receives them as a trailing
# Hash, as before.
#
# @return [Object] whatever #forward returns
def call(*input, **kwargs, &block)
  forward(*input, **kwargs, &block)
end

#children ⇒ Object



119
120
121
# File 'lib/torch/nn/module.rb', line 119

# Returns the immediate child modules as an Array, without their names.
def children
  named_children.each_value.to_a
end

#cpu ⇒ Object



53
54
55
# File 'lib/torch/nn/module.rb', line 53

# Converts this module's tensors with +Tensor#cpu+ via #_apply.
def cpu
  _apply(lambda { |tensor| tensor.cpu })
end

#cuda(device: nil) ⇒ Object



49
50
51
# File 'lib/torch/nn/module.rb', line 49

# Converts this module's tensors with +Tensor#cuda+ via #_apply.
#
# @param device [Object, nil] passed straight through to +Tensor#cuda+
def cuda(device: nil)
  to_cuda = ->(tensor) { tensor.cuda(device) }
  _apply(to_cuda)
end

#double ⇒ Object



65
66
67
# File 'lib/torch/nn/module.rb', line 65

# Casts floating-point tensors with +Tensor#double+ via #_apply.
# Non-floating tensors are passed through unchanged.
def double
  to_double = lambda do |tensor|
    next tensor unless tensor.floating_point?

    tensor.double
  end
  _apply(to_double)
end

#eval ⇒ Object



151
152
153
# File 'lib/torch/nn/module.rb', line 151

# Switches the module tree to evaluation mode, i.e. #train(false).
#
# @return whatever #train returns
def eval
  training_mode = false
  train(training_mode)
end

#float ⇒ Object



61
62
63
# File 'lib/torch/nn/module.rb', line 61

# Casts floating-point tensors with +Tensor#float+ via #_apply.
# Non-floating tensors are passed through unchanged.
def float
  _apply(lambda do |tensor|
    if tensor.floating_point?
      tensor.float
    else
      tensor
    end
  end)
end

#forward ⇒ Object

Raises:

  • (NotImplementedError)


13
14
15
# File 'lib/torch/nn/module.rb', line 13

# Computes the module's output; subclasses must override this.
#
# Accepts and ignores any arguments. #call forwards its inputs here, so a
# zero-arity stub made +call(x)+ on a module that never overrode
# #forward fail with ArgumentError instead of the intended
# NotImplementedError.
#
# @raise [NotImplementedError] always
def forward(*_input, **_kwargs)
  raise NotImplementedError, "#{self.class} must implement #forward"
end

#half ⇒ Object



69
70
71
# File 'lib/torch/nn/module.rb', line 69

# Casts floating-point tensors with +Tensor#half+ via #_apply.
# Non-floating tensors are passed through unchanged.
def half
  to_half = ->(tensor) { tensor.floating_point? ? tensor.half : tensor }
  _apply(to_half)
end

#inspect ⇒ Object



175
176
177
178
179
180
181
182
183
184
185
186
187
# File 'lib/torch/nn/module.rb', line 175

# Human-readable description of the module tree.
#
# A module without children renders as "Name(<extra_inspect>)". Otherwise
# each child is listed on its own line as "  (child_name): <child repr>".
# Continuation lines of a multi-line child repr are indented two more
# spaces so that nesting is visible.
#
# The previous version iterated +children+ (an Array of modules) with
# +|name, mod|+. That bound the module to +name+ and +nil+ to +mod+, so
# every line printed the module object followed by ": nil".
#
# @return [String]
def inspect
  name = self.class.name.split("::").last
  kids = named_children
  return "#{name}(#{extra_inspect})" if kids.empty?

  str = String.new
  str << "#{name}(\n"
  kids.each do |child_name, mod|
    str << "  (#{child_name}): #{mod.inspect.gsub("\n", "\n  ")}\n"
  end
  str << ")"
end

#modules ⇒ Object



135
136
137
# File 'lib/torch/nn/module.rb', line 135

# Returns the modules from #named_modules (this module first) as an Array.
def modules
  named_modules.to_a.map(&:last)
end

#named_buffers ⇒ Object



115
116
117
# File 'lib/torch/nn/module.rb', line 115

# Returns the buffer registry (name => tensor).
#
# Falls back to a new empty Hash when @buffers is unset, e.g. when a
# subclass's initialize did not call super.
def named_buffers
  return {} unless @buffers

  @buffers
end

#named_children ⇒ Object



123
124
125
126
127
128
129
130
131
132
133
# File 'lib/torch/nn/module.rb', line 123

# Returns the immediate child modules keyed by name.
#
# Children come from two places:
# * every instance variable holding a Module, keyed by the variable name
#   without its leading "@";
# * the @modules registry filled by #add_module. These entries override
#   same-named instance variables.
#
# Guards against an unset @modules, matching #named_buffers' handling of
# @buffers. Previously a subclass that skipped super crashed here, and
# therefore also in #inspect, #parameters, #train, and so on.
#
# @return [Hash{String => Module}]
def named_children
  found = {}
  instance_variables.each do |ivar|
    value = instance_variable_get(ivar)
    found[ivar[1..-1]] = value if value.is_a?(Module)
  end
  found.merge!(@modules) if @modules
  found
end

#named_modules ⇒ Object



139
140
141
# File 'lib/torch/nn/module.rb', line 139

# Returns this module under the key "" followed by its immediate children.
# A child registered under "" would replace the self entry.
def named_modules
  result = {"" => self}
  result.update(named_children)
end

#named_parameters(prefix: "", recurse: true) ⇒ Object



94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# File 'lib/torch/nn/module.rb', line 94

# Returns the parameters of this module, keyed by dotted name.
#
# With +recurse+, parameters of children (and their descendants) are
# included under "child.grandchild.param". Direct parameters are found
# in instance variables holding a Parameter and in the @parameters
# registry. Registry entries override same-named instance variables.
#
# The recursive call now extends the accumulated +prefix+. Previously it
# restarted from the child's own name, so grandchild keys lost their
# ancestry. Siblings with same-named children (e.g. two blocks that each
# contain "fc") then collided, and one set of parameters was silently
# dropped.
#
# @param prefix [String] prepended to every key (ends with "." when non-empty)
# @param recurse [Boolean] whether to include descendants
# @return [Hash{String => Parameter}]
def named_parameters(prefix: "", recurse: true)
  params = {}
  if recurse
    named_children.each do |name, mod|
      params.merge!(mod.named_parameters(prefix: "#{prefix}#{name}.", recurse: recurse))
    end
  end
  instance_variables.each do |name|
    param = instance_variable_get(name)
    params[[prefix, name[1..-1]].join] = param if param.is_a?(Parameter)
  end
  # @parameters may be unset when a subclass skipped super.
  (@parameters || {}).each do |name, param|
    params[[prefix, name].join] = param
  end
  params
end

#parameters ⇒ Object



90
91
92
# File 'lib/torch/nn/module.rb', line 90

# Returns every parameter from #named_parameters (recursive) as an Array.
def parameters
  named_parameters.map { |_name, param| param }
end

#register_buffer(name, tensor) ⇒ Object



17
18
19
20
21
# File 'lib/torch/nn/module.rb', line 17

# Registers +tensor+ as a buffer named +name+.
#
# The tensor is stored in @buffers and also exposed as the instance
# variable "@<name>".
#
# @return [Object] +tensor+
def register_buffer(name, tensor)
  # TODO add checks
  @buffers.store(name, tensor)
  instance_variable_set(:"@#{name}", tensor)
end

#register_parameter(name, param) ⇒ Object



23
24
25
26
# File 'lib/torch/nn/module.rb', line 23

# Registers +param+ in the @parameters registry under +name+.
#
# @return [Object] +param+
def register_parameter(name, param)
  # TODO add checks
  @parameters.store(name, param)
end

#requires_grad!(requires_grad: true) ⇒ Object



155
156
157
158
159
160
# File 'lib/torch/nn/module.rb', line 155

# Calls +requires_grad!(requires_grad)+ on every parameter in the tree.
#
# @param requires_grad [Boolean]
# @return [self]
def requires_grad!(requires_grad: true)
  tap do
    parameters.each { |param| param.requires_grad!(requires_grad) }
  end
end

#respond_to?(method, include_private = false) ⇒ Boolean

Returns:

  • (Boolean)


202
203
204
205
# File 'lib/torch/nn/module.rb', line 202

# Reports the dynamic readers served by #method_missing.
#
# A name responds when it matches a key of #named_parameters,
# #named_buffers or #named_modules. This is the hook Ruby consults for
# methods that are not actually defined. Object#respond_to? therefore
# returns the same answers as the previous respond_to? override. In
# addition, +method(:weight)+ now works, and ordinary method queries no
# longer rebuild the parameter table.
#
# @param method [Symbol, String]
# @param include_private [Boolean]
# @return [Boolean]
def respond_to_missing?(method, include_private = false)
  name = method.to_s
  named_parameters.key?(name) || named_buffers.key?(name) || named_modules.key?(name) || super
end

#share_memory ⇒ Object



171
172
173
# File 'lib/torch/nn/module.rb', line 171

# Applies +Tensor#share_memory!+ to this module's tensors via #_apply.
def share_memory
  _apply(lambda { |tensor| tensor.share_memory! })
end

#state_dict ⇒ Object

Raises:

  • (NotImplementedYet)

86
87
88
# File 'lib/torch/nn/module.rb', line 86

# Not supported yet.
#
# @raise [NotImplementedYet] always
def state_dict
  raise(NotImplementedYet)
end

#to(device) ⇒ Object

Modifies the module in place.



74
75
76
77
78
79
80
# File 'lib/torch/nn/module.rb', line 74

# Converts this module's tensors with +Tensor#to(device)+ via #_apply.
#
# @param device [Object] passed straight through to +Tensor#to+
def to(device)
  _apply(->(tensor) { tensor.to(device) })
end

#train(mode = true) ⇒ Object



143
144
145
146
147
148
149
# File 'lib/torch/nn/module.rb', line 143

# Sets the training flag on this module and, recursively, its children.
#
# @param mode [Boolean] true for training, false for evaluation
# @return [self]
def train(mode = true)
  tap do
    @training = mode
    children.each { |child| child.train(mode) }
  end
end

#type(dst_type) ⇒ Object



57
58
59
# File 'lib/torch/nn/module.rb', line 57

# Converts this module's tensors with +Tensor#type(dst_type)+ via #_apply.
#
# @param dst_type [Object] passed straight through to +Tensor#type+
def type(dst_type)
  retype = lambda { |tensor| tensor.type(dst_type) }
  _apply(retype)
end

#zero_grad ⇒ Object



162
163
164
165
166
167
168
169
# File 'lib/torch/nn/module.rb', line 162

# Resets gradients of every parameter in the tree.
#
# For each parameter whose +grad+ is set, the gradient is detached and
# then zeroed in place. Parameters without a gradient are skipped.
#
# @return [Array] the result of #parameters
def zero_grad
  parameters.each do |param|
    next unless param.grad

    param.grad.detach!
    param.grad.zero!
  end
end