Class: TorchVision::Models::ResNet

Inherits:
Torch::NN::Module
  • Object
show all
Defined in:
lib/torchvision/models/resnet.rb

Constant Summary collapse

# Maps an architecture name to the URL of its pretrained checkpoint.
# Frozen so the shared lookup table cannot be mutated by accident.
MODEL_URLS =
{
  "resnet18" => "https://download.pytorch.org/models/resnet18-f37072fd.pth",
  "resnet34" => "https://download.pytorch.org/models/resnet34-b627a593.pth",
  "resnet50" => "https://download.pytorch.org/models/resnet50-0676ba61.pth",
  "resnet101" => "https://download.pytorch.org/models/resnet101-63fe2227.pth",
  "resnet152" => "https://download.pytorch.org/models/resnet152-394f9c45.pth",
  "resnext50_32x4d" => "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
  "resnext101_32x8d" => "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
  "wide_resnet50_2" => "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
  "wide_resnet101_2" => "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth"
}.freeze

Class Method Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(block, layers, num_classes = 1000, zero_init_residual: false, groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil) ⇒ ResNet



16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# File 'lib/torchvision/models/resnet.rb', line 16

# Assembles the network: a convolutional stem, four residual stages built by
# +_make_layer+, adaptive average pooling and a final linear layer.
#
# @param block [Class] residual block class; must respond to +expansion+ and
#   accept the arguments +_make_layer+ passes to +new+
# @param layers [Array<Integer>] block counts for the four stages (indices 0..3 are read)
# @param num_classes [Integer] output size of the final linear layer
# @param zero_init_residual [Boolean] when truthy, sets the weight of +bn3+
#   (Bottleneck) or +bn2+ (BasicBlock) to zero in every such block
# @param groups [Integer] stored and forwarded to every block as +groups:+
# @param width_per_group [Integer] stored and forwarded to every block as +base_width:+
# @param replace_stride_with_dilation [Array<Boolean>, nil] one flag per
#   strided stage (layer2..layer4); nil means +[false, false, false]+
# @param norm_layer [Class, nil] normalization layer class; defaults to
#   Torch::NN::BatchNorm2d
# @raise [ArgumentError] if +replace_stride_with_dilation+ does not have length 3
def initialize(
  block,
  layers,
  num_classes = 1000,
  zero_init_residual: false,
  groups: 1,
  width_per_group: 64,
  replace_stride_with_dilation: nil,
  norm_layer: nil
)
  super()
  @norm_layer = norm_layer || Torch::NN::BatchNorm2d

  @inplanes = 64
  @dilation = 1

  # Each flag says whether that stage's stride of 2 should be replaced
  # with a dilated convolution instead.
  replace_stride_with_dilation = [false, false, false] if replace_stride_with_dilation.nil?
  unless replace_stride_with_dilation.length == 3
    raise ArgumentError, "replace_stride_with_dilation should be nil or a 3-element tuple, got #{replace_stride_with_dilation}"
  end

  @groups = groups
  @base_width = width_per_group

  # Stem: 7x7 stride-2 convolution, normalization, ReLU, 3x3 stride-2 max pool.
  @conv1 = Torch::NN::Conv2d.new(3, @inplanes, 7, stride: 2, padding: 3, bias: false)
  @bn1 = @norm_layer.new(@inplanes)
  @relu = Torch::NN::ReLU.new(inplace: true)
  @maxpool = Torch::NN::MaxPool2d.new(3, stride: 2, padding: 1)

  # Residual stages. They are built strictly in order because _make_layer
  # updates @inplanes and @dilation as it goes.
  @layer1 = _make_layer(block, 64, layers[0])
  @layer2, @layer3, @layer4 = [128, 256, 512].each_with_index.map do |planes, idx|
    _make_layer(block, planes, layers[idx + 1], stride: 2, dilate: replace_stride_with_dilation[idx])
  end

  # Classification head.
  @avgpool = Torch::NN::AdaptiveAvgPool2d.new([1, 1])
  @fc = Torch::NN::Linear.new(512 * block.expansion, num_classes)

  # Default initialization for convolution and normalization layers.
  modules.each do |mod|
    if mod.is_a?(Torch::NN::Conv2d)
      Torch::NN::Init.kaiming_normal!(mod.weight, mode: "fan_out", nonlinearity: "relu")
    elsif mod.is_a?(Torch::NN::BatchNorm2d) || mod.is_a?(Torch::NN::GroupNorm)
      Torch::NN::Init.constant!(mod.weight, 1)
      Torch::NN::Init.constant!(mod.bias, 0)
    end
  end

  # Optionally zero the last norm layer of each residual branch so that the
  # branch starts at zero and every block initially behaves like an identity.
  # Reported to improve accuracy by 0.2~0.3% in https://arxiv.org/abs/1706.02677
  return unless zero_init_residual

  modules.each do |mod|
    if mod.is_a?(Bottleneck)
      Torch::NN::Init.constant!(mod.bn3.weight, 0)
    elsif mod.is_a?(BasicBlock)
      Torch::NN::Init.constant!(mod.bn2.weight, 0)
    end
  end
end

Class Method Details

.make_model(arch, block, layers, pretrained: false, **kwargs) ⇒ Object



125
126
127
128
129
130
131
132
133
# File 'lib/torchvision/models/resnet.rb', line 125

# Builds a ResNet and, when requested, loads pretrained weights into it.
#
# @param arch [String] architecture name; used as the key into MODEL_URLS
#   (only consulted when +pretrained+ is truthy)
# @param block [Class] residual block class passed through to ResNet.new
# @param layers [Array<Integer>] per-stage block counts passed through to ResNet.new
# @param pretrained [Boolean] when truthy, fetch the state dict from the URL
#   registered for +arch+ and load it into the model
# @param kwargs [Hash] extra keyword arguments forwarded to ResNet.new
# @return [ResNet] the constructed model
# @raise [ArgumentError] if +pretrained+ is truthy and +arch+ has no entry in MODEL_URLS
def self.make_model(arch, block, layers, pretrained: false, **kwargs)
  # Fail early with a clear message; otherwise a nil URL would be handed to
  # the downloader after the whole network had already been constructed.
  if pretrained && !MODEL_URLS.key?(arch)
    raise ArgumentError, "no pretrained weights for #{arch.inspect}; known architectures: #{MODEL_URLS.keys.join(', ')}"
  end

  model = ResNet.new(block, layers, **kwargs)
  if pretrained
    state_dict = Torch::Hub.load_state_dict_from_url(MODEL_URLS[arch])
    model.load_state_dict(state_dict)
  end
  model
end

Instance Method Details

#_forward_impl(x) ⇒ Object



103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
# File 'lib/torchvision/models/resnet.rb', line 103

# Runs the input through the network: stem (conv1, bn1, relu, maxpool), the
# four residual stages, average pooling, a flatten from dimension 1 onward,
# and the final linear layer.
#
# @param x [Object] input passed to +@conv1.call+ (presumably an image batch
#   tensor; confirm against callers)
# @return [Object] whatever +@fc.call+ returns
def _forward_impl(x)
  # Every stage is applied in order, each one receiving the previous output.
  feature_stages = [
    @conv1, @bn1, @relu, @maxpool,
    @layer1, @layer2, @layer3, @layer4,
    @avgpool
  ]
  features = feature_stages.reduce(x) { |acc, stage| stage.call(acc) }

  @fc.call(Torch.flatten(features, 1))
end

#_make_layer(block, planes, blocks, stride: 1, dilate: false) ⇒ Object



78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# File 'lib/torchvision/models/resnet.rb', line 78

# Builds one residual stage as a Torch::NN::Sequential of +blocks+ blocks
# (at least one block is always created).
#
# Side effects: multiplies @dilation by +stride+ when +dilate+ is truthy, and
# sets @inplanes to +planes * block.expansion+.
#
# @param block [Class] residual block class; must respond to +expansion+
# @param planes [Integer] base channel count for the stage
# @param blocks [Integer] number of blocks in the stage
# @param stride [Integer] stride of the first block (forced to 1 when +dilate+ is truthy)
# @param dilate [Boolean] trade the stride for dilation in this stage
# @return [Torch::NN::Sequential] the assembled stage
def _make_layer(block, planes, blocks, stride: 1, dilate: false)
  norm = @norm_layer
  out_planes = planes * block.expansion

  # The first block keeps the dilation in effect before this stage.
  first_dilation = @dilation
  if dilate
    @dilation *= stride
    stride = 1
  end

  # A 1x1 convolution plus normalization on the shortcut is needed whenever
  # the first block changes the stride or the channel count.
  shortcut =
    if stride != 1 || @inplanes != out_planes
      Torch::NN::Sequential.new(
        Torch::NN::Conv2d.new(@inplanes, out_planes, 1, stride: stride, bias: false),
        norm.new(out_planes)
      )
    end

  head = block.new(
    @inplanes, planes,
    stride: stride, downsample: shortcut, groups: @groups,
    base_width: @base_width, dilation: first_dilation, norm_layer: norm
  )
  @inplanes = out_planes

  # Remaining blocks use the (possibly updated) dilation and the new width.
  tail = (1...blocks).map do
    block.new(@inplanes, planes, groups: @groups, base_width: @base_width, dilation: @dilation, norm_layer: norm)
  end

  Torch::NN::Sequential.new(head, *tail)
end

#forward(x) ⇒ Object



121
122
123
# File 'lib/torchvision/models/resnet.rb', line 121

# Forward pass entry point; delegates directly to +_forward_impl+.
#
# @param x [Object] input handed unchanged to +_forward_impl+
# @return [Object] the value returned by +_forward_impl+
def forward(x)
  _forward_impl(x)
end