Class: TorchRec::Modules::MLP::MLP

Inherits: Torch::NN::Module

Defined in: lib/torchrec/modules/mlp/mlp.rb

Instance Method Summary

  #forward(input) ⇒ Object

Constructor Details

#initialize(in_size, layer_sizes, bias: true, activation: :relu, device: nil) ⇒ MLP

Returns a new instance of MLP. in_size is the input feature dimension, layer_sizes is an array giving the output size of each successive layer, bias controls whether each layer has a bias term, activation is :relu (default), :sigmoid, :swish_layernorm, or a callable applied after each layer, and device selects where parameters are allocated.



# File 'lib/torchrec/modules/mlp/mlp.rb', line 5

def initialize(in_size, layer_sizes, bias: true, activation: :relu, device: nil)
  super()

  # Resolve the built-in symbol shortcuts to callables.
  if activation == :relu
    activation = Torch.method(:relu)
  elsif activation == :sigmoid
    activation = Torch.method(:sigmoid)
  end

  if !activation.is_a?(Symbol)
    # A callable activation is shared across all layers.
    @mlp = Torch::NN::Sequential.new(
      *layer_sizes.length.times.map do |i|
        Perceptron.new(
          i > 0 ? layer_sizes[i - 1] : in_size,
          layer_sizes[i],
          bias: bias,
          activation: Utils.extract_module_or_tensor_callable(activation),
          device: device
        )
      end
    )
  elsif activation == :swish_layernorm
    # SwishLayerNorm is size-dependent, so one is built per layer.
    @mlp = Torch::NN::Sequential.new(
      *layer_sizes.length.times.map do |i|
        Perceptron.new(
          i > 0 ? layer_sizes[i - 1] : in_size,
          layer_sizes[i],
          bias: bias,
          activation: Activation::SwishLayerNorm.new(layer_sizes[i], device: device),
          device: device
        )
      end
    )
  else
    raise ArgumentError, "This MLP only supports activation function of :relu, :sigmoid, and :swish_layernorm"
  end
end
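
A minimal usage sketch (shapes are illustrative; assumes torch-rb and the torchrec gem are loaded):

  require "torchrec"

  # 10 input features through layers of size 16, 8, and 4, with the default :relu.
  mlp = TorchRec::Modules::MLP::MLP.new(10, [16, 8, 4])

  # The :swish_layernorm variant builds a SwishLayerNorm per layer.
  swish_mlp = TorchRec::Modules::MLP::MLP.new(10, [16, 8, 4], activation: :swish_layernorm)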

Instance Method Details

#forward(input) ⇒ Object

Passes input through the stacked Perceptron layers and returns the output tensor.



# File 'lib/torchrec/modules/mlp/mlp.rb', line 45

def forward(input)
  # Delegate to the Sequential stack of Perceptron layers.
  @mlp.call(input)
end
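
A short sketch of invoking the module built above (Torch::NN::Module makes #call dispatch to #forward):

  input = Torch.randn(32, 10)  # batch of 32 examples, 10 features each
  output = mlp.call(input)     # shape [32, 4] for layer_sizes [16, 8, 4]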