最近在学习PyTorch,想着光看不动手不行,就尝试着对着torchvision.models下的ResNet实现写了一遍,顺带将ResNet复习了一下,并将在研读源码过程中自己做的一点笔记和理解简单的分享一下。
torchvision.models.resnet下有ResNet18、ResNet34、ResNet50、ResNet101、ResNet152五种结构的ResNet实现。
主要的区别在于:
- 18层和34层的ResNet用到的结构是"两个3*3的conv层堆叠"的BasicBlock,而50,101,152层的ResNet用到了Bottleneck结构,所以源码中也分别有BasicBlock和Bottleneck的实现。
- 相同的BasicBlock/Bottleneck重复堆叠的次数不同。
贴一张ResNet的结构图可能会更清晰一些:
从而主体核心代码ResNet类的初始化形式即为:
# block:BasicBlock or Bottleneck
# layers: 根据自己需要搭积木。ResNet50的设置为layers=[3, 4, 6, 3]
# num_classes:根据自己需要设置最终fc层输出的分类数目,默认是ImageNet的1000类
def __init__(self, block, layers, num_classes=1000):
所以接下来主要重点解读Bottleneck类和ResNet类的实现。
一、Bottleneck类源码解读
同样,先贴一张Bottleneck的结构图(图片来自网络,依据tensorflow画的,PyTorch的image Tensor应该是(N,C,H,W)),图和代码相结合更有助于理解。
class Bottleneck(nn.Module):
'''B