前言
本学习笔记参考自B站up主霹雳吧啦Wz
代码均来自导师github开源项目WZMIAOMIAO/deep-learning-for-image-processing: deep learning for image processing including classification and object-detection etc. (github.com)
一、VGG网络结构
vgg有多种网络结构
![](https://s2.loli.net/2022/05/30/AnmGkIqQ2Ja1STl.png)
其中常用的是VGG16的网络结构
![](https://s2.loli.net/2022/05/30/SfuXDrP8CkY3bzw.png)
网络中的亮点:通过堆叠多个3×3得到卷积核来代替大尺度卷积核,以减少所需的参数。论文中提到,可以通过堆叠两个3×3的卷积核代替5×5的卷积核,堆叠三个3×3的卷积核代替7×7的卷积核
拥有相同的感受野
1.1、感受野的计算
什么是感受野:
输出矩阵上的一个单元对应的输入层上的区域大小
例如此处,第三层一个特征对应第一层的区域是5×5,所以说第三层的感受野就是5×5
![](https://s2.loli.net/2022/05/30/4cvkNzQLZn5xlg6.png)
计算公式:
F
(
i
)
=
(
F
(
i
+
1
)
−
1
)
×
S
t
r
i
d
e
+
K
s
i
z
e
F(i)=(F(i+1)-1)\times Stride + Ksize
F(i)=(F(i+1)−1)×Stride+Ksize
![](https://s2.loli.net/2022/05/30/1DB2d3TmjMNanGu.png)
再回到之前说的用多个3×3的卷积核代替大卷积核
![](https://s2.loli.net/2022/05/30/KZrFhkUbMnmf5p3.png)
可以很清楚的看到参数少了不少
1、model
这里的搭建与之前有所不同的是,因为VGG有多种网络配置,所以我们建立一个make_features
的函数来生成不同的网络结构
def make_features(cfg:list): # 接受参数为列表
layers = []
in_channels = 3
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
layers += [conv2d, nn.ReLU(True)]
in_channels = v
return nn.Sequential(*layers) # 返回这里用到了一个可变参数(*layers),具体用法可以自行查找
然后我们在定义一个字典cfgs来储存特征层参数
其中数字代表这一层是卷积层,数字为卷积核的个数。
“M”为这一层是最大下采样池化层
cfgs = {
'vgg11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'vgg13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'vgg16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
'vgg19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}
最后我们只需要在初始化函数里面加上变量features
来接收我们的特征就可以了
def __init__(self, features, num_classes=1000, init_weights=False):
同时我们还需要搭建一个接收函数来接收选择的网络层命令
def vgg(model_name='vgg16', **kwargs): # 使用关键字参数**kwargs
assert model_name in cfgs, "Warning: model number {} not in cfgs dict!".format(model_name)
cfg = cfgs[model_name]
model = VGG(make_features(cfg), **kwargs)
return model
2、train and predict
训练模块和预测模块与上一张alexnet的也就大同小异了,这里就不讲了
最后在感叹一下,霹导师的代码写的真漂亮!