本文对resnet-50的模型代码细节进行分析
首先导入需要的包
import torch.nn as nn
import math
import torch.utils.model_zoo as model_zoo
其次是ResNet50中的残差块Bottleneck,源码中还有对Basicblock的定义。但是Basicblock主要是针对resnet18以及resnet34的,而Bottleneck主要是针对resnet50以及更深的resnet网络的。
这里定义一个Bottleneck类,其父类是nn.Module。
class Bottleneck(nn.Module): #预定义网络架构并定义前向传播
expansion = 4 #expansion是指残差块输出维度是输入维度的多少倍。在ResNet类中的_make_layer函数中会用到
def __init__(self,in_planes,planes,stride=1,downsample=None): #初始化并继承nn.Module中的一些属性。in_planes指输入的通道数,planes指输出的通道数,步长默认为1,下采样函数默认为空(即默认不需要下采样)
super(Bottleneck,self).__init__() #定义nn.Module的子类Bottleneck。并在下面给出新的属性
self.conv1 = nn.Conv2d(in_planes,planes,kernel_size=1,bias=False)
self.bn1 = nn.BatchNorm2d(planes) #归一化处理
self.conv2 = nn.Conv2d(in_planes,planes,kernel_size=3,stride=stride,padding=1,bias=False