Class: TorchRec::Modules::MLP::MLP
- Inherits: Torch::NN::Module
  - Object
    - Torch::NN::Module
      - TorchRec::Modules::MLP::MLP
- Defined in: lib/torchrec/modules/mlp/mlp.rb
Instance Method Summary
- #forward(input) ⇒ Object
- #initialize(in_size, layer_sizes, bias: true, activation: :relu, device: nil) ⇒ MLP (constructor): A new instance of MLP.
Constructor Details
#initialize(in_size, layer_sizes, bias: true, activation: :relu, device: nil) ⇒ MLP
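Returns a new instance of MLP: a Torch::NN::Sequential stack with one Perceptron per entry in layer_sizes. The first layer maps in_size to layer_sizes[0], and each later layer maps layer_sizes[i - 1] to layer_sizes[i]. activation may be :relu, :sigmoid, :swish_layernorm, or a non-Symbol callable; any other Symbol raises ArgumentError.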
```ruby
# File 'lib/torchrec/modules/mlp/mlp.rb', line 5

def initialize(in_size, layer_sizes, bias: true, activation: :relu, device: nil)
  super()

  if activation == :relu
    activation = Torch.method(:relu)
  elsif activation == :sigmoid
    activation = Torch.method(:sigmoid)
  end

  if !activation.is_a?(Symbol)
    @mlp = Torch::NN::Sequential.new(
      *layer_sizes.length.times.map do |i|
        Perceptron.new(
          i > 0 ? layer_sizes[i - 1] : in_size,
          layer_sizes[i],
          bias: bias,
          activation: Utils.extract_module_or_tensor_callable(activation),
          device: device
        )
      end
    )
  else
    if activation == :swish_layernorm
      @mlp = Torch::NN::Sequential.new(
        *layer_sizes.length.times.map do |i|
          Perceptron.new(
            i > 0 ? layer_sizes[i - 1] : in_size,
            layer_sizes[i],
            bias: bias,
            activation: Activation::SwishLayerNorm.new(layer_sizes[i], device: device),
            device: device
          )
        end
      )
    else
      raise ArgumentError, "This MLP only supports activation function of :relu, :sigmoid, and :swish_layernorm"
    end
  end
end
```
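The constructor does two pieces of bookkeeping before building any layers: it chains layer widths (each Perceptron's input width is the previous layer's output width, starting from in_size), and it resolves the activation option, rejecting unsupported Symbols. A minimal pure-Ruby sketch of just that logic follows, with no Torch dependency. `layer_shapes` and `resolve_activation` are hypothetical helpers, not part of torchrec-ruby, and the scalar lambdas only stand in for `Torch.method(:relu)` and `Torch.method(:sigmoid)`. Unlike the real constructor, the sketch returns a non-Symbol callable as-is instead of passing it through `Utils.extract_module_or_tensor_callable`, and it merely tags `:swish_layernorm` rather than building an `Activation::SwishLayerNorm` per layer.

```ruby
# Hypothetical helper: the [in_features, out_features] pair for each layer,
# mirroring `i > 0 ? layer_sizes[i - 1] : in_size` in MLP#initialize.
def layer_shapes(in_size, layer_sizes)
  layer_sizes.each_with_index.map do |out_size, i|
    [i > 0 ? layer_sizes[i - 1] : in_size, out_size]
  end
end

# Hypothetical helper mirroring the activation dispatch in MLP#initialize.
# Scalar lambdas stand in for Torch.method(:relu) / Torch.method(:sigmoid).
def resolve_activation(activation)
  case activation
  when :relu then ->(x) { x > 0 ? x : 0.0 }
  when :sigmoid then ->(x) { 1.0 / (1.0 + Math.exp(-x)) }
  when :swish_layernorm then :swish_layernorm # real code builds one per layer
  when Symbol
    raise ArgumentError,
          "This MLP only supports activation function of :relu, :sigmoid, and :swish_layernorm"
  else
    activation # sketch only: any non-Symbol callable is passed through unchanged
  end
end

p layer_shapes(16, [32, 8, 1]) # => [[16, 32], [32, 8], [8, 1]]
```

Because the widths are derived from layer_sizes alone, callers only state each layer's output width; a mismatch between adjacent layers cannot be expressed.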
Instance Method Details
#forward(input) ⇒ Object
```ruby
# File 'lib/torchrec/modules/mlp/mlp.rb', line 45

def forward(input)
  @mlp.call(input)
end
```
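forward simply delegates to the Sequential container, which feeds the output of each Perceptron into the next. The sketch below illustrates that data flow on plain Ruby arrays. It assumes each Perceptron computes activation(W x + b), which its name suggests but the source shown here does not confirm; `perceptron` is a hypothetical stand-in, and the weights are arbitrary illustration values, not anything torchrec-ruby produces.

```ruby
# Hypothetical stand-in for a Perceptron: y = activation(W x + b) on plain arrays.
# The real Perceptron operates on Torch tensors; this only shows the data flow.
def perceptron(weights, bias, activation)
  lambda do |x|
    weights.each_with_index.map do |row, j|
      z = row.zip(x).sum { |w, xi| w * xi } + bias[j]
      activation.call(z)
    end
  end
end

relu = ->(z) { z > 0 ? z : 0.0 }

layers = [
  perceptron([[1.0, -1.0], [0.5, 0.5]], [0.0, 0.0], relu), # 2 -> 2
  perceptron([[1.0, 1.0]], [1.0], relu)                    # 2 -> 1
]

# Like Sequential#call: thread the input through each layer in order.
forward = ->(input) { layers.reduce(input) { |x, layer| layer.call(x) } }

p forward.call([3.0, 1.0]) # => [5.0]
```

Here the first layer maps [3.0, 1.0] to [2.0, 2.0] and the second maps that to [5.0], the same left-to-right composition `@mlp.call(input)` performs over the Perceptron stack.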