一、框架表
二、代码定义
1.定义18,34层残差结构
class BasicBlock(nn.Module):
    """Two-layer residual block used by ResNet-18 and ResNet-34."""

    # The last conv of this block keeps the nominal width (no expansion).
    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # First 3x3 conv; `stride` > 1 downsamples the feature map here.
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        # Second 3x3 conv keeps both spatial size and channel count.
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # Optional 1x1 projection for the shortcut, used when the main
        # branch changes shape (stride != 1 or channel mismatch).
        self.downsample = downsample

    def forward(self, x):
        # Shortcut branch: plain identity, or projected to match shapes.
        shortcut = x if self.downsample is None else self.downsample(x)
        # Main branch: conv-bn-relu followed by conv-bn.
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        # Residual addition, then the final activation.
        y = y + shortcut
        return self.relu(y)
代码解释
- 数据经过
7x7,64,stride 2
之后的规格就变成了112x112x64
,再经过3x3 max pool,stride 2
之后就变成了56x56x64(padding=1)
- 接下来就接入上述的残差结构。以34层为例,conv2_x 先堆叠3个残差块,conv3_x 再堆叠4个残差块,两组输出的通道数不同(分别为64和128)。
- conv2_x 中各残差块的输入特征图规格与输出规格相同,捷径分支可以直接相加。进入 conv3_x 时,主分支输出变为128通道且空间尺寸减半,因此捷径分支必须先经过一个 stride=2 的 1x1 卷积调整通道数和尺寸,才能与主分支的输出相加。这就是上述代码中的
downsample
。
2.定义50,101,152层残差结构
class Bottleneck(nn.Module):
    """Three-layer bottleneck residual block used by ResNet-50/101/152.

    Main branch: 1x1 conv (reduce) -> 3x3 conv -> 1x1 conv (expand by
    `expansion`). The optional `downsample` projects the shortcut so it
    matches the expanded output.
    """

    # The final 1x1 conv widens the block's nominal width by this factor.
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # 1x1 conv reduces channels to the block's nominal width.
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        # 3x3 conv; `stride` > 1 downsamples the feature map here.
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        # 1x1 conv expands channels by `expansion`.
        self.conv3 = nn.Conv2d(in_channels=out_channel,
                               out_channels=out_channel * self.expansion,
                               kernel_size=1, stride=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        # Optional 1x1 projection for the shortcut branch.
        self.downsample = downsample

    def forward(self, x):
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)
        # BUG FIX: the original called self.conv2 again here, which skipped
        # the channel expansion and fed a mismatched tensor into bn3
        # (bn3 expects out_channel * expansion channels).
        out = self.conv3(out)
        out = self.bn3(out)
        out += identity
        out = self.relu(out)
        return out
代码说明
- 原论文中只有实线的残差结构,所以我根据
conv3_x
的50层画出了虚线的残差结构。 - 数据经过
7x7,64,stride 2
之后的规格就变成了112x112x64
,再经过3x3 max pool,stride 2
之后就变成了56x56x64(padding=1)
- 在50层网络中,
conv2_x
的第一个残差块虽然不改变空间尺寸,但主分支输出通道由64扩展为256,因此其捷径分支同样需要 1x1 卷积(stride=1)调整通道,属于虚线残差结构;conv2_x 到 conv3_x 的过渡则使用 stride=2 的虚线残差结构,同时减半空间尺寸并调整通道数。其余位置为实线(恒等捷径)的残差结构。
3.定义ResNet网络
接下来定义主网络
class ResNet(nn.Module):
    """ResNet backbone.

    `block` is the residual block class (BasicBlock or Bottleneck) and
    `blocks_num` lists the block count of each of the four stages.
    When `include_top` is True a global-average-pool + linear classifier
    head is appended.
    """

    def __init__(self, block, blocks_num, num_classes=1000, include_top=True):
        super(ResNet, self).__init__()
        self.include_top = include_top
        # Channel depth entering the first residual stage (after max pool).
        self.in_channel = 64
        # Stem: 7x7/2 conv + BN + ReLU, then 3x3/2 max pool.
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7,
                               stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four residual stages; stages 2-4 halve the spatial size.
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            # Global average pool to 1x1, then the classification head.
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.fc = nn.Linear(512 * block.expansion, num_classes)
        # He initialization for every convolution in the network.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out',
                                        nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        """Build one residual stage of `block_num` blocks.

        Only the first block may downsample or change the channel count;
        it receives a projection shortcut when needed.
        """
        downsample = None
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion),
            )
        layers = [block(self.in_channel, channel,
                        downsample=downsample, stride=stride)]
        # Remaining blocks of the stage see the expanded channel count.
        self.in_channel = channel * block.expansion
        layers.extend(block(self.in_channel, channel)
                      for _ in range(1, block_num))
        return nn.Sequential(*layers)

    def forward(self, x):
        # Stem.
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        # Residual stages.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        # Optional classification head.
        if self.include_top:
            x = torch.flatten(self.avgpool(x), 1)
            x = self.fc(x)
        return x
4. 函数调用
def resnet34(num_classes=1000, include_top=True):
    """Build a ResNet-34: BasicBlock with stage depths (3, 4, 6, 3)."""
    return ResNet(BasicBlock, [3, 4, 6, 3],
                  num_classes=num_classes, include_top=include_top)
def resnet101(num_classes=1000, include_top=True):
    """Build a ResNet-101: Bottleneck with stage depths (3, 4, 23, 3)."""
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes, include_top=include_top)