Class: TorchVision::Models::ResNet

Inherits:
Torch::NN::Module
  • Object
show all
Defined in:
lib/torchvision/models/resnet.rb

Instance Method Summary

Constructor Details

#initialize(block, layers, num_classes = 1000, zero_init_residual: false, groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil) ⇒ ResNet

Returns a new instance of ResNet.



4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# File 'lib/torchvision/models/resnet.rb', line 4

# Assembles the network: a convolutional stem, four residual stages built by
# _make_layer, adaptive average pooling and a final linear layer, followed by
# weight initialization of every registered submodule.
#
# @param block [Class] residual block class; must respond to +expansion+.
#   Bottleneck and BasicBlock are the two classes handled by zero_init_residual.
# @param layers [#[]] block counts per stage; indices 0 through 3 are read
# @param num_classes [Integer] output size of the final linear layer
# @param zero_init_residual [Boolean] when truthy, zero the weight of the last
#   norm layer (bn3 for Bottleneck, bn2 for BasicBlock) in every residual block
# @param groups [Integer] stored as @groups
# @param width_per_group [Integer] stored as @base_width
# @param replace_stride_with_dilation [#length, #[], nil] three flags, one each
#   for stages 2, 3 and 4; nil means [false, false, false]
# @param norm_layer [Class, nil] normalization layer class; falls back to
#   Torch::NN::BatchNorm2d when nil or false
# @raise [ArgumentError] if replace_stride_with_dilation does not have length 3
def initialize(block, layers, num_classes=1000, zero_init_residual: false,
  groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil)

  super()
  @norm_layer = norm_layer || Torch::NN::BatchNorm2d

  # Running state read and updated by _make_layer while the stages are built.
  @inplanes = 64
  @dilation = 1

  # Each flag says whether that stage's stride of 2 should be swapped for a
  # dilated convolution instead.
  replace_stride_with_dilation = [false, false, false] if replace_stride_with_dilation.nil?
  unless replace_stride_with_dilation.length == 3
    raise ArgumentError, "replace_stride_with_dilation should be nil or a 3-element tuple, got #{replace_stride_with_dilation}"
  end

  @groups = groups
  @base_width = width_per_group

  # Stem
  @conv1 = Torch::NN::Conv2d.new(3, @inplanes, 7, stride: 2, padding: 3, bias: false)
  @bn1 = @norm_layer.new(@inplanes)
  @relu = Torch::NN::ReLU.new(inplace: true)
  @maxpool = Torch::NN::MaxPool2d.new(3, stride: 2, padding: 1)

  # Residual stages; they must be built in this order because each call
  # advances @inplanes (and possibly @dilation) for the next one.
  @layer1 = _make_layer(block, 64, layers[0])
  @layer2 = _make_layer(block, 128, layers[1], stride: 2, dilate: replace_stride_with_dilation[0])
  @layer3 = _make_layer(block, 256, layers[2], stride: 2, dilate: replace_stride_with_dilation[1])
  @layer4 = _make_layer(block, 512, layers[3], stride: 2, dilate: replace_stride_with_dilation[2])

  # Head
  @avgpool = Torch::NN::AdaptiveAvgPool2d.new([1, 1])
  @fc = Torch::NN::Linear.new(512 * block.expansion, num_classes)

  # First pass: Kaiming-normal for convolutions, weight 1 / bias 0 for norms.
  modules.each do |mod|
    if mod.is_a?(Torch::NN::Conv2d)
      Torch::NN::Init.kaiming_normal!(mod.weight, mode: "fan_out", nonlinearity: "relu")
    elsif mod.is_a?(Torch::NN::BatchNorm2d) || mod.is_a?(Torch::NN::GroupNorm)
      Torch::NN::Init.constant!(mod.weight, 1)
      Torch::NN::Init.constant!(mod.bias, 0)
    end
  end

  return unless zero_init_residual

  # Second pass, deliberately separate so it overrides the weight of 1 set above:
  # zero the last BN in each residual branch so that the branch starts at zero
  # and every residual block behaves like an identity.
  # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
  modules.each do |mod|
    if mod.is_a?(Bottleneck)
      Torch::NN::Init.constant!(mod.bn3.weight, 0)
    elsif mod.is_a?(BasicBlock)
      Torch::NN::Init.constant!(mod.bn2.weight, 0)
    end
  end
end

Instance Method Details

#_forward_impl(x) ⇒ Object



84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# File 'lib/torchvision/models/resnet.rb', line 84

# Feeds +x+ through the network in a fixed order: conv1, bn1, relu, maxpool,
# layer1..layer4, avgpool; the pooled result is flattened from dimension 1
# onward and handed to the final fc layer.
#
# @param x [Object] network input (presumably an image batch tensor with
#   3 channels, matching conv1 -- confirm against callers)
# @return [Object] the output of the fc layer
def _forward_impl(x)
  stem = [@conv1, @bn1, @relu, @maxpool]
  stages = [@layer1, @layer2, @layer3, @layer4]

  # Thread the value through every module in sequence.
  pooled = (stem + stages + [@avgpool]).reduce(x) { |out, mod| mod.call(out) }

  @fc.call(Torch.flatten(pooled, 1))
end

#_make_layer(block, planes, blocks, stride: 1, dilate: false) ⇒ Object



59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# File 'lib/torchvision/models/resnet.rb', line 59

# Builds one residual stage as a Torch::NN::Sequential of +block+ instances.
#
# The first block receives the stride and, when needed, a 1x1 conv + norm
# shortcut; the remaining blocks - 1 blocks use the block's default stride.
# Side effects: @inplanes becomes planes * block.expansion, and @dilation is
# multiplied by +stride+ when +dilate+ is truthy.
#
# @param block [Class] residual block class responding to +expansion+
# @param planes [Integer] base channel width of the stage
# @param blocks [Integer] number of blocks in the stage
# @param stride [Integer] stride of the first block (forced to 1 when dilating)
# @param dilate [Boolean] trade the stride for dilation
# @return [Torch::NN::Sequential]
def _make_layer(block, planes, blocks, stride: 1, dilate: false)
  # The first block keeps the dilation in effect before this stage.
  first_dilation = @dilation
  if dilate
    @dilation *= stride
    stride = 1
  end

  out_planes = planes * block.expansion

  # A projection shortcut is needed whenever the first block changes the
  # spatial resolution or the channel count.
  shortcut = nil
  unless stride == 1 && @inplanes == out_planes
    shortcut = Torch::NN::Sequential.new(
      Torch::NN::Conv2d.new(@inplanes, out_planes, 1, stride: stride, bias: false),
      @norm_layer.new(out_planes)
    )
  end

  shared = {groups: @groups, base_width: @base_width, norm_layer: @norm_layer}

  head = block.new(@inplanes, planes, stride: stride, downsample: shortcut, dilation: first_dilation, **shared)
  @inplanes = out_planes

  # A non-positive count yields no extra blocks, just like Integer#times with a block.
  tail = (blocks - 1).times.map do
    block.new(@inplanes, planes, dilation: @dilation, **shared)
  end

  Torch::NN::Sequential.new(head, *tail)
end

#forward(x) ⇒ Object



102
103
104
# File 'lib/torchvision/models/resnet.rb', line 102

# Forward-pass entry point; delegates directly to _forward_impl.
#
# @param x [Object] input, handed unchanged to _forward_impl
# @return [Object] whatever _forward_impl returns
def forward(x)
  _forward_impl(x)
end