Class: TorchRec::Modules::CrossNet::CrossNet

Inherits:
Torch::NN::Module (ancestor chain: Object > Torch::NN::Module > CrossNet)
Defined in:
lib/torchrec/modules/cross_net/cross_net.rb

Instance Method Summary

#initialize(in_features, num_layers) ⇒ CrossNet (constructor)
#forward(input) ⇒ Object

Constructor Details

#initialize(in_features, num_layers) ⇒ CrossNet

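Returns a new instance of CrossNet with num_layers cross layers. Each layer holds an in_features × in_features weight matrix (Xavier-normal initialized) and an in_features × 1 bias column (zero-initialized).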



Source (lines 5–20 of lib/torchrec/modules/cross_net/cross_net.rb):
# File 'lib/torchrec/modules/cross_net/cross_net.rb', line 5

# Builds +num_layers+ stacked cross layers of width +in_features+.
#
# Each layer owns:
# * a weight matrix of shape (in_features, in_features), filled in place by
#   Torch::NN::Init.xavier_normal!, stored in @kernels; and
# * a bias column of shape (in_features, 1), filled in place by
#   Torch::NN::Init.zeros!, stored in @bias.
# Both collections are wrapped in Torch::NN::ParameterList and indexed by
# layer number in #forward.
#
# @param in_features [Integer] width of each input feature vector
# @param num_layers [Integer] number of cross layers to stack
def initialize(in_features, num_layers)
  super()
  @num_layers = num_layers
  # The layer index is not needed to build each parameter, so the blocks
  # take no argument.
  @kernels = Torch::NN::ParameterList.new(
    @num_layers.times.map do
      Torch::NN::Parameter.new(
        Torch::NN::Init.xavier_normal!(Torch.empty(in_features, in_features))
      )
    end
  )
  @bias = Torch::NN::ParameterList.new(
    @num_layers.times.map do
      Torch::NN::Parameter.new(Torch::NN::Init.zeros!(Torch.empty(in_features, 1)))
    end
  )
end

Instance Method Details

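#forward(input) ⇒ Object

Applies the stacked cross layers to input and returns the result of Torch.squeeze on the final layer's output.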



Source (lines 22–32 of lib/torchrec/modules/cross_net/cross_net.rb):
# File 'lib/torchrec/modules/cross_net/cross_net.rb', line 22

# Runs +input+ through every cross layer in sequence.
#
# With x_0 = input.unsqueeze(2), each layer l computes
#   x_{l+1} = x_0 * (W_l x_l + b_l) + x_l
# where W_l is @kernels[l] (applied via Torch.matmul) and b_l is @bias[l].
# The dimension inserted at index 2 is removed again before returning.
#
# NOTE(review): assumes +input+ is (batch, in_features), which would make
# every intermediate (batch, in_features, 1) — confirm against callers.
#
# @param input [Object] presumably a Torch::Tensor; must respond to #unsqueeze
# @return [Object] the result of Torch.squeeze(..., dim: 2) on the last layer's output
def forward(input)
  base = input.unsqueeze(2)

  crossed = @num_layers.times.reduce(base) do |current, idx|
    projected = Torch.matmul(@kernels[idx], current)
    base * (projected + @bias[idx]) + current
  end

  Torch.squeeze(crossed, dim: 2)
end