Class: TorchVision::Models::AlexNet

Inherits:
Torch::NN::Module < Object
Defined in:
lib/torchvision/models/alexnet.rb

Instance Method Summary

Constructor Details

#initialize(num_classes: 1000) ⇒ AlexNet

Returns a new instance of AlexNet.



# File 'lib/torchvision/models/alexnet.rb', lines 4-31

# Builds the AlexNet layer stack. It has three stages:
# - a convolutional feature extractor,
# - an adaptive average pool to a fixed 6x6 grid,
# - a fully connected classifier head.
#
# @param num_classes [Integer] out_features of the final Linear layer, i.e.
#   one output per class. Defaults to 1000.
# @param dropout [Float] drop probability passed as `p:` to both Dropout
#   layers in the classifier. Defaults to 0.5, which is torch.rb's own
#   Dropout default, so omitting it reproduces the previous hard-coded
#   behaviour.
def initialize(num_classes: 1000, dropout: 0.5)
  super()
  # Feature extractor: five convolutions, each followed by an in-place ReLU.
  # A 3x3 / stride-2 max pool follows the 1st, 2nd and 5th convolution.
  @features = Torch::NN::Sequential.new(
    Torch::NN::Conv2d.new(3, 64, 11, stride: 4, padding: 2),
    Torch::NN::ReLU.new(inplace: true),
    Torch::NN::MaxPool2d.new(3, stride: 2),
    Torch::NN::Conv2d.new(64, 192, 5, padding: 2),
    Torch::NN::ReLU.new(inplace: true),
    Torch::NN::MaxPool2d.new(3, stride: 2),
    Torch::NN::Conv2d.new(192, 384, 3, padding: 1),
    Torch::NN::ReLU.new(inplace: true),
    Torch::NN::Conv2d.new(384, 256, 3, padding: 1),
    Torch::NN::ReLU.new(inplace: true),
    Torch::NN::Conv2d.new(256, 256, 3, padding: 1),
    Torch::NN::ReLU.new(inplace: true),
    Torch::NN::MaxPool2d.new(3, stride: 2)
  )
  # Pooling to a fixed 6x6 output is what lets the first classifier Linear
  # hard-code its input width as 256 * 6 * 6.
  @avgpool = Torch::NN::AdaptiveAvgPool2d.new([6, 6])
  @classifier = Torch::NN::Sequential.new(
    Torch::NN::Dropout.new(p: dropout),
    Torch::NN::Linear.new(256 * 6 * 6, 4096),
    Torch::NN::ReLU.new(inplace: true),
    Torch::NN::Dropout.new(p: dropout),
    Torch::NN::Linear.new(4096, 4096),
    Torch::NN::ReLU.new(inplace: true),
    Torch::NN::Linear.new(4096, num_classes)
  )
end

Instance Method Details

#forward(x) ⇒ Object



# File 'lib/torchvision/models/alexnet.rb', lines 33-39

# Runs one forward pass, in this order:
# 1. the feature extractor,
# 2. adaptive average pooling,
# 3. flattening every dimension from dim 1 onward,
# 4. the classifier head.
#
# @param x input handed to the feature extractor; presumably an image batch
#   shaped (N, 3, H, W) to match the first Conv2d — TODO confirm with callers.
# @return the output of the classifier head
def forward(x)
  pooled = @avgpool.call(@features.call(x))
  flat = Torch.flatten(pooled, 1)
  @classifier.call(flat)
end