Class: TorchRec::Modules::CrossNet::CrossNet
- Inherits: Torch::NN::Module
  - Object
  - Torch::NN::Module
  - TorchRec::Modules::CrossNet::CrossNet
- Defined in: lib/torchrec/modules/cross_net/cross_net.rb
Instance Method Summary
- #forward(input) ⇒ Object
- #initialize(in_features, num_layers) ⇒ CrossNet (constructor): A new instance of CrossNet.
Constructor Details
#initialize(in_features, num_layers) ⇒ CrossNet
Returns a new instance of CrossNet.
5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
# File 'lib/torchrec/modules/cross_net/cross_net.rb', line 5
#
# Creates the learnable parameters for every cross layer: one square kernel
# and one column bias per layer, each wrapped in a Torch::NN::ParameterList.
#
# @param in_features [Integer] side length of each square kernel and the
#   number of rows in each bias column
# @param num_layers [Integer] how many kernel/bias pairs to create
def initialize(in_features, num_layers)
  super()
  @num_layers = num_layers

  # Kernels: (in_features x in_features), Xavier-normal initialised in place.
  kernel_params = @num_layers.times.map do
    weight = Torch::NN::Init.xavier_normal!(Torch.empty(in_features, in_features))
    Torch::NN::Parameter.new(weight)
  end
  @kernels = Torch::NN::ParameterList.new(kernel_params)

  # Biases: (in_features x 1), zero-initialised in place.
  bias_params = @num_layers.times.map do
    zeros = Torch::NN::Init.zeros!(Torch.empty(in_features, 1))
    Torch::NN::Parameter.new(zeros)
  end
  @bias = Torch::NN::ParameterList.new(bias_params)
end
Instance Method Details
#forward(input) ⇒ Object
22 23 24 25 26 27 28 29 30 31 32 |
# File 'lib/torchrec/modules/cross_net/cross_net.rb', line 22
#
# Runs the input through all cross layers in sequence. Each layer computes
# x_0 * (kernel x_l + bias) + x_l, where x_0 is the original input with a
# trailing axis added and x_l is the previous layer's output.
#
# @param input [Object] tensor to transform; presumably shaped (B, N) given
#   the (B, N, 1) annotation after unsqueeze — TODO confirm against callers
# @return [Object] result of the last layer with axis 2 squeezed out
def forward(input)
  base = input.unsqueeze(2) # (B, N, 1)

  # Fold over the layer indices, threading the running activation through;
  # with zero layers the fold simply yields `base`.
  crossed = @num_layers.times.reduce(base) do |state, idx|
    projected = Torch.matmul(@kernels[idx], state) # (B, N, 1)
    base * (projected + @bias[idx]) + state # (B, N, 1)
  end

  Torch.squeeze(crossed, dim: 2)
end