Class: TorchVision::Models::ResNet
- Inherits: Torch::NN::Module
  - Ancestry: Object > Torch::NN::Module > TorchVision::Models::ResNet
- Defined in:
- lib/torchvision/models/resnet.rb
Instance Method Summary
- #_forward_impl(x) ⇒ Object
- #_make_layer(block, planes, blocks, stride: 1, dilate: false) ⇒ Object
- #forward(x) ⇒ Object
- #initialize(block, layers, num_classes = 1000, zero_init_residual: false, groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil) ⇒ ResNet (constructor): A new instance of ResNet.
Constructor Details
#initialize(block, layers, num_classes = 1000, zero_init_residual: false, groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil) ⇒ ResNet
Returns a new instance of ResNet.
Source lines 4-57 of lib/torchvision/models/resnet.rb:
# File 'lib/torchvision/models/resnet.rb', line 4 def initialize(block, layers, num_classes=1000, zero_init_residual: false, groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil) super() norm_layer ||= Torch::NN::BatchNorm2d @norm_layer = norm_layer @inplanes = 64 @dilation = 1 if replace_stride_with_dilation.nil? # each element in the tuple indicates if we should replace # the 2x2 stride with a dilated convolution instead replace_stride_with_dilation = [false, false, false] end if replace_stride_with_dilation.length != 3 raise ArgumentError, "replace_stride_with_dilation should be nil or a 3-element tuple, got #{replace_stride_with_dilation}" end @groups = groups @base_width = width_per_group @conv1 = Torch::NN::Conv2d.new(3, @inplanes, 7, stride: 2, padding: 3, bias: false) @bn1 = norm_layer.new(@inplanes) @relu = Torch::NN::ReLU.new(inplace: true) @maxpool = Torch::NN::MaxPool2d.new(3, stride: 2, padding: 1) @layer1 = _make_layer(block, 64, layers[0]) @layer2 = _make_layer(block, 128, layers[1], stride: 2, dilate: replace_stride_with_dilation[0]) @layer3 = _make_layer(block, 256, layers[2], stride: 2, dilate: replace_stride_with_dilation[1]) @layer4 = _make_layer(block, 512, layers[3], stride: 2, dilate: replace_stride_with_dilation[2]) @avgpool = Torch::NN::AdaptiveAvgPool2d.new([1, 1]) @fc = Torch::NN::Linear.new(512 * block.expansion, num_classes) modules.each do |m| case m when Torch::NN::Conv2d Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu") when Torch::NN::BatchNorm2d, Torch::NN::GroupNorm Torch::NN::Init.constant!(m.weight, 1) Torch::NN::Init.constant!(m.bias, 0) end end # Zero-initialize the last BN in each residual branch, # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 if zero_init_residual modules.each do |m| case m when Bottleneck Torch::NN::Init.constant!(m.bn3.weight, 0) when BasicBlock Torch::NN::Init.constant!(m.bn2.weight, 0) end end end end |
Instance Method Details
#_forward_impl(x) ⇒ Object
Source lines 84-100 of lib/torchvision/models/resnet.rb:
# File 'lib/torchvision/models/resnet.rb', line 84
#
# Runs the network on +x+. The input is threaded in order through
# conv1, bn1, relu, maxpool, layer1..layer4 and avgpool. The result is
# flattened from dimension 1 onward and passed to the fully connected layer.
#
# @param x [Object] network input (presumably a Torch::Tensor — confirm)
# @return [Object] the output of @fc
def _forward_impl(x)
  pipeline = [@conv1, @bn1, @relu, @maxpool, @layer1, @layer2, @layer3, @layer4, @avgpool]
  pooled = pipeline.inject(x) { |activation, stage| stage.call(activation) }
  @fc.call(Torch.flatten(pooled, 1))
end
#_make_layer(block, planes, blocks, stride: 1, dilate: false) ⇒ Object
Source lines 59-82 of lib/torchvision/models/resnet.rb:
# File 'lib/torchvision/models/resnet.rb', line 59
#
# Builds one residual stage as a Torch::NN::Sequential of block instances.
# The first block is always built; +blocks - 1+ further blocks follow it.
#
# Side effects:
# - @inplanes is set to +planes * block.expansion+.
# - When +dilate+ is true, @dilation is multiplied by +stride+ and the
#   stride used for this stage becomes 1.
#
# @param block [Class] block class responding to +expansion+ and +new+
# @param planes [Integer] base width; the stage's blocks are sized for
#   planes * block.expansion output channels
# @param blocks [Integer] number of blocks in the stage
# @param stride [Integer] stride handed to the first block (and its downsample conv)
# @param dilate [Boolean] replace the stride with an increased dilation
# @return [Torch::NN::Sequential]
def _make_layer(block, planes, blocks, stride: 1, dilate: false)
  out_channels = planes * block.expansion
  # The first block keeps the dilation that was in effect before this stage.
  entry_dilation = @dilation
  if dilate
    @dilation *= stride
    stride = 1
  end

  # A 1x1 conv + norm downsample module is only needed when the first block
  # changes the stride or the channel count.
  projection =
    if stride == 1 && @inplanes == out_channels
      nil
    else
      Torch::NN::Sequential.new(
        Torch::NN::Conv2d.new(@inplanes, out_channels, 1, stride: stride, bias: false),
        @norm_layer.new(out_channels)
      )
    end

  head = block.new(
    @inplanes, planes,
    stride: stride, downsample: projection, groups: @groups,
    base_width: @base_width, dilation: entry_dilation, norm_layer: @norm_layer
  )
  @inplanes = out_channels
  tail = (1...blocks).map do
    block.new(
      @inplanes, planes,
      groups: @groups, base_width: @base_width, dilation: @dilation, norm_layer: @norm_layer
    )
  end

  Torch::NN::Sequential.new(head, *tail)
end
#forward(x) ⇒ Object
Source lines 102-104 of lib/torchvision/models/resnet.rb:
# File 'lib/torchvision/models/resnet.rb', line 102
#
# Forward-pass entry point. All of the computation lives in #_forward_impl;
# this method returns that method's result for +x+ unchanged.
#
# @param x [Object] network input (presumably a Torch::Tensor — confirm)
# @return [Object] whatever #_forward_impl returns
def forward(x)
  output = _forward_impl(x)
  output
end