Class: TorchRec::Modules::Activation::SwishLayerNorm

Inherits:
Torch::NN::Module
  • Object
Defined in:
lib/torchrec/modules/activation/swish_layer_norm.rb

Instance Method Summary

Constructor Details

#initialize(input_dims, device: nil) ⇒ SwishLayerNorm

Returns a new instance of SwishLayerNorm.



# File 'lib/torchrec/modules/activation/swish_layer_norm.rb', line 5

def initialize(input_dims, device: nil)
  super()
  # TODO: pass `device:` through to LayerNorm; the argument is currently unused.
  @norm = Torch::NN::Sequential.new(
    Torch::NN::LayerNorm.new(input_dims),
    Torch::NN::Sigmoid.new
  )
end

Instance Method Details

#forward(input) ⇒ Object

Returns the input multiplied elementwise by Sigmoid(LayerNorm(input)).



# File 'lib/torchrec/modules/activation/swish_layer_norm.rb', line 14

def forward(input)
  input * @norm.call(input)
end
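
The arithmetic behind `forward` can be sketched in plain Ruby, without any tensor library. This is only an illustration for a single 1-D vector, not the module's implementation: the helper names `layer_norm`, `sigmoid`, and `swish_layer_norm` are hypothetical, and the sketch assumes LayerNorm uses a biased variance, an epsilon of 1e-5, and an affine weight of 1 and bias of 0 (its usual initial state).

```ruby
# Pure-Ruby sketch of x * sigmoid(layer_norm(x)) for one 1-D vector.
# Assumptions (not taken from the source): eps = 1e-5, biased variance,
# affine weight fixed at 1 and bias fixed at 0.
EPS = 1e-5

# Normalize the vector to zero mean and (approximately) unit variance.
def layer_norm(xs, eps: EPS)
  mean = xs.sum / xs.size.to_f
  var = xs.sum { |x| (x - mean)**2 } / xs.size.to_f
  xs.map { |x| (x - mean) / Math.sqrt(var + eps) }
end

# Logistic sigmoid of a scalar.
def sigmoid(x)
  1.0 / (1.0 + Math.exp(-x))
end

# Gate each input element by the sigmoid of its normalized value,
# mirroring `input * @norm.call(input)`.
def swish_layer_norm(xs)
  layer_norm(xs).zip(xs).map { |normed, x| x * sigmoid(normed) }
end

p swish_layer_norm([1.0, 2.0, 3.0])
```

For the input `[1.0, 2.0, 3.0]` the middle element equals the mean, so its normalized value is 0, its gate is 0.5, and its output is 1.0. Elements below the mean are scaled by a gate under 0.5, and elements above it by a gate over 0.5.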