本篇文章以代码为主要展现形式,穿插与代码部分的注释对各关键部位进行了解释
ResNet是由Kaiming He为主的几位研究员提出的具有重要意义的模型结构。
ResNet的出现对于 “神经网络层数达到一定数量后,越深的神经网络反而会降低模型精度” 这个问题提出了一种可能的解决方案——即通过“shortcut”进行层次短接:shortcut保证低层次模型得到良好的训练,同时给予模型再次提升的能力,即若在原低层模型以外存在进一步提升模型能力的可能则进行尝试并提升。
1. 结构分析
上图所示即为原模型中的基本结构(ResNetBlock),一个block中有两个卷积层以及一个直连shortcut。输出为经过block的
F
(
x
)
F(x)
F(x)加上直连输入
x
x
x,则有
o
u
t
p
u
t
=
F
(
x
)
+
x
output=F(x)+x
output=F(x)+x. 根据基本结构可知以下限制:经过block的
F
(
x
)
F(x)
F(x) 必须与输出
x
x
x 维度信息保持一致,在之后的代码部分会有注释
2.ResNet18实现
本篇文章选择18层ResNet进行分析与展示,其他层次模型结构大致相同(下面贴出一张原论文中的结构列表)。
ResNet18由17个卷积层以及1个全连接构成。以下为代码展示:
import torch
import torch.nn as nn
from torch.nn import functional as F
import numpy as np
# 构建ResNet基本块
class ResNetBlock(nn.Module):
def __init__(self, ch_in, ch_out, stride=1):
"""
:param ch_in: 输入维度
:param ch_out: 输出维度
"""
super(ResNetBlock, self).__init__()
self.CV_1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
self.BN_1 = nn.BatchNorm2d(ch_out)
self.CV_2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
self.BN_2 = nn.BatchNorm2d(ch_out)
# 定义直连层
self.EX = nn.Sequential()
if ch_in != ch_out:
# 如果经过block与直连的维度不匹配,那么就无法进行直接相加,需要进行调整
# 使用kernel_size=1不改变尺度信息,stride=stride使得尺寸信息一直
self.EX = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride)
)
def forward(self, x):
# 经过block内容的结果
out_1 = F.relu(self.BN_1(self.CV_1(x)))
out_1 = self.BN_2(self.CV_2(out_1))
# 未经过block的结果为,EX(input)
out_2 = self.EX(x)
# 最终输出结果
return out_1 + out_2
# 自定义一个展平模块
class Flatten(nn.Module):
def __init__(self):
super(Flatten, self).__init__()
def forward(self, input):
print(input.size(0), input.size(1), input.size(2), input.size(3))
return input.view(input.size(0), -1)
# 定义ResNet18
class ResNet18(nn.Module):
def __init__(self, ch_in, ch_out):
"""
ResNet18由17个卷积层+1个全连接层构成
:param ch_in: 输入维度
:param ch_out: 输出维度
"""
super(ResNet18, self).__init__()
self.module = nn.Sequential(
# 1,[b,3,32,32]=>[b,64,16,16]
nn.Conv2d(ch_in, 64, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
# # 2,3,4,5层;[b,64,16,16]=>[b,128,8,8]
ResNetBlock(64, 64, stride=1),
ResNetBlock(64, 128, stride=2),
# # 6,7,8,9层;[b,128,8,8]=>[b,256,4,4]
ResNetBlock(128, 128, stride=1),
ResNetBlock(128, 256, stride=2),
# # 10,11,12,13层;[b,256,4,4]=>[b,512,2,2]
ResNetBlock(256, 256, stride=1),
ResNetBlock(256, 512, stride=2),
# # 14,15,16,17层;[b,512,2,2]=>[b,512,1,1]
ResNetBlock(512, 512, stride=1),
ResNetBlock(512, 1024, stride=2),
# [b,1024,1,1]=>[b,1024*1*1]
Flatten(),
# 18层, 全连接层;[b,1024]=>[b,ch_out]
nn.Linear(1024, ch_out)
)
def forward(self, x):
out = self.module(x)
return out
def main():
# 随机生成输入数据
input = torch.randn(10, 3, 32, 32)
# 定义ResNet,3表示输出ch,10表示最终分类类别
net = ResNet18(3, 10)
out = net.forward(input)
# [10,10]
print(out.shape)
if __name__ == '__main__':
main()
注意:本文中的实现基于cifar-10数据集,可以直接应用于对应数据集的训练
3.小结
以上即为本篇文章的全部内容
转载请注明出处,感谢!