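[Foreword] Hi everyone, I'm 猪葛,
an algorithm engineer who is very optimistic about the future of AI.
In this blog series I will keep publishing Keras tutorials (an outline is at the end of the post).
The content falls into two parts:
Part one covers the basics of Keras.
Part two uses Keras to build the Faster R-CNN and YOLO object detection networks.
The code is highly reusable.
If you are interested too, follow my updates and let's learn together.
Study tips:
Some of the material can be confusing at first; just work through it against the "learning objectives".
Take it one step at a time. Once you reach the summit and look back down, the whole landscape becomes clear.
Learning objectives for this post: 1. understand the GoogLeNet architecture; 2. learn to build the GoogLeNet architecture yourself.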
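1. Building the GoogLeNet CNN with the Functional API
Before building the GoogLeNet convolutional neural network, we first need to understand what its architecture actually looks like.
1-1. Introduction to GoogLeNet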
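GoogLeNet comes from the paper "Going deeper with convolutions", published by Christian Szegedy, Yangqing Jia and colleagues at Google. It won the ImageNet classification task of the 2014 ILSVRC challenge; the runner-up that year was the VGG family covered in the previous post.
Interestingly, the name GoogLeNet combines the prefix "Goog" from Google with LeNet, as a tribute to Yann LeCun's pioneering LeNet-5 network.
The biggest contribution of GoogLeNet is the now-classic Inception module. The hallmark of the architecture is how efficiently it uses computing resources inside the network: the design lets the depth and width of the network grow while the computational budget stays constant, so GoogLeNet reaches 22 layers while using only about 1/12 of the parameters of AlexNet.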
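1-2. Network components
1-2-1. The Inception module
The structure of the module is shown in the figure:
The Inception module comes in two forms, a naive version and a version with 1x1 dimension-reduction convolutions; the paper uses the second form.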
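To see why the second form is preferred, here is a minimal sketch that counts Conv2D parameters for both forms. It assumes a 192-channel input and the filter counts of the first Inception module used later in this post; the helper `conv_params` is a name made up for this illustration.

```python
def conv_params(kernel, c_in, c_out):
    """Weights plus biases of a Conv2D layer with a square kernel."""
    return kernel * kernel * c_in * c_out + c_out


c_in = 192  # channels entering the first Inception module
# Filter counts in the order [1x1, 3x3 reduce, 3x3, 5x5 reduce, 5x5, pool proj]
f = [64, 96, 128, 16, 32, 32]

# Naive form: the 1x1, 3x3 and 5x5 convolutions all read the full 192-channel input
naive = conv_params(1, c_in, f[0]) + conv_params(3, c_in, f[2]) + conv_params(5, c_in, f[4])

# Dimension-reduced form: 1x1 convolutions shrink the channels before the 3x3 and 5x5
reduced = (
    conv_params(1, c_in, f[0])
    + conv_params(1, c_in, f[1]) + conv_params(3, f[1], f[2])
    + conv_params(1, c_in, f[3]) + conv_params(5, f[3], f[4])
    + conv_params(1, c_in, f[5])  # 1x1 projection after the pooling branch
)

print(naive, reduced)  # 387296 163696
```

Even though the second form contains three extra 1x1 layers, it needs fewer than half the parameters of the naive form in this example.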
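One note:
- The diagram states the kernel size of every convolution but not the stride or padding. Checking the paper, every convolution here uses stride 1, padding 'same', and 'relu' activation; the max pooling layer also uses stride 1 and padding 'same'.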
Because this block is used so often, I wrap it in a function:
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, BatchNormalization, Concatenate
from tensorflow.keras.layers import Dense, Flatten, Dropout, AveragePooling2D, Softmax
from tensorflow.keras.models import Model


def inception(x, filters):
    """
    Inception module (dimension-reduced form).
    :param x: input tensor
    :param filters: list of six ints, the number of filters in the order
                    [1x1, 3x3 reduce, 3x3, 5x5 reduce, 5x5, pool proj]
    :return: output tensor
    """
    # Branch 1 (leftmost in the diagram): 1x1 convolution
    branch1 = Conv2D(filters[0], (1, 1), strides=1, padding='same', activation='relu')(x)
    # Branch 2: 1x1 reduction followed by a 3x3 convolution
    branch2 = Conv2D(filters[1], (1, 1), strides=1, padding='same', activation='relu')(x)
    branch2 = Conv2D(filters[2], (3, 3), strides=1, padding='same', activation='relu')(branch2)
    # Branch 3: 1x1 reduction followed by a 5x5 convolution
    branch3 = Conv2D(filters[3], (1, 1), strides=1, padding='same', activation='relu')(x)
    branch3 = Conv2D(filters[4], (5, 5), strides=1, padding='same', activation='relu')(branch3)
    # Branch 4: 3x3 max pooling followed by a 1x1 projection
    branch4 = MaxPooling2D((3, 3), strides=1, padding='same')(x)
    branch4 = Conv2D(filters[5], (1, 1), strides=1, padding='same', activation='relu')(branch4)
    # Concatenate the branches along the last (channel) axis,
    # the "Filter concatenation" step in the diagram
    y = Concatenate()([branch1, branch2, branch3, branch4])
    return y
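1-2-2. Auxiliary classifiers
Because the network is fairly deep, its ability to propagate gradients back through all layers effectively is limited, and vanishing gradients may occur. Two auxiliary classifiers are therefore attached to intermediate layers. They encourage the lower layers to produce features that are already useful for classification, which strengthens the gradient signal that is propagated back and helps avoid vanishing gradients.
During training, their losses are added to the total loss of the network with a discount weight (each auxiliary loss is weighted by 0.3). At test time these auxiliary networks are discarded. Their structure is as follows: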
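Four notes:
- The first layer is an AveragePooling2D average pooling layer with pool size (5, 5), stride 3, and padding 'valid'.
- The second layer is a convolution with 128 filters of size (1, 1), stride 1, and padding 'same'.
- The third and fourth layers are fully connected layers with 1024 and 1000 units. The 1024-unit layer uses 'relu'; in the paper the 1000-unit layer is linear and feeds the softmax.
- The last layer is a softmax activation layer.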
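As a quick check of the first two notes, the sketch below computes the spatial size after the average pooling layer. It assumes the 14x14 feature maps that the auxiliary classifiers receive in this network and the stride of 3 used in the paper; `valid_out` is a helper name invented for this example.

```python
def valid_out(n, kernel, stride):
    """Output length along one axis for padding='valid'."""
    return (n - kernel) // stride + 1


side = valid_out(14, kernel=5, stride=3)  # 14x14 input to the auxiliary classifier
features = side * side * 128              # values per image after the 128-filter 1x1 convolution

print(side, features)  # 4 2048
```

With stride 1 instead of 3 the same formula gives a 10x10 map, which is why the stride matters for the size of the fully connected layer that follows.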
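Because the auxiliary classifier is used twice, I wrap it in a function as well:
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, BatchNormalization, Concatenate
from tensorflow.keras.layers import Dense, Flatten, Dropout, AveragePooling2D, Softmax
from tensorflow.keras.models import Model


def y_branch_loss(x):
    """
    Auxiliary classifier.
    :param x: input tensor
    :return: output tensor of shape (batch, 1000)
    """
    # 5x5 average pooling with stride 3; padding defaults to 'valid'
    y_branch = AveragePooling2D((5, 5), strides=3)(x)
    # 1x1 convolution with 128 filters
    y_branch = Conv2D(128, (1, 1), strides=1, padding='same', activation='relu')(y_branch)
    # Flatten so that the Dense layers see one feature vector per image
    # instead of being applied at every spatial position
    y_branch = Flatten()(y_branch)
    y_branch = Dense(1024, activation='relu')(y_branch)
    # Linear 1000-way layer; the Softmax layer below turns it into probabilities
    y_branch = Dense(1000)(y_branch)
    y_branch = Softmax()(y_branch)
    return y_branch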
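The discounted loss described at the start of this subsection can be sketched in plain Python as follows. The loss values are made-up numbers for illustration only, and `total_loss` is a hypothetical helper, not part of Keras.

```python
def total_loss(main_loss, aux_losses, aux_weight=0.3):
    """Main loss plus the auxiliary losses, each discounted by aux_weight."""
    return main_loss + aux_weight * sum(aux_losses)


# Hypothetical per-batch losses: main classifier, then the two auxiliary classifiers
loss = total_loss(2.0, [2.5, 3.0])
print(round(loss, 2))  # 3.65
```

In Keras the same weighting is usually expressed through the `loss_weights` argument of `model.compile`, for example `loss_weights=[1.0, 0.3, 0.3]` for a model with one main output and two auxiliary outputs.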
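1-2-3. Overall GoogLeNet architecture
With the two subsections above understood, the overall architecture is easy to follow. Here is the complete architecture diagram:
We will build the whole GoogLeNet from this diagram and from the parameter table below (given in the original paper):
Two notes:
- '#3x3 reduce' is the number of 1x1 filters placed in front of the '#3x3' convolution (if this is unclear, reread subsection 1-2-1).
- '#5x5 reduce' is the number of 1x1 filters placed in front of the '#5x5' convolution.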
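To make the table columns concrete, here is a small sketch that stores each Inception row in the order used by the `filters` list in this post and checks the output depth of every module. The dictionary name `TABLE` and the helper `output_depth` are invented for this example; the numbers are the ones passed to `inception` in the full listing.

```python
# Column order: (#1x1, #3x3 reduce, #3x3, #5x5 reduce, #5x5, pool proj)
TABLE = {
    "3a": (64, 96, 128, 16, 32, 32),
    "3b": (128, 128, 192, 32, 96, 64),
    "4a": (192, 96, 208, 16, 48, 64),
    "4b": (160, 112, 224, 24, 64, 64),
    "4c": (128, 128, 256, 24, 64, 64),
    "4d": (112, 144, 288, 32, 64, 64),
    "4e": (256, 160, 320, 32, 128, 128),
    "5a": (256, 160, 320, 32, 128, 128),
    "5b": (384, 192, 384, 48, 128, 128),
}


def output_depth(filters):
    """Channels after concatenation: the two 'reduce' columns do not reach the output."""
    one, _reduce3, three, _reduce5, five, pool = filters
    return one + three + five + pool


depths = {name: output_depth(row) for name, row in TABLE.items()}
print(depths)
```

The two "reduce" values only size the intermediate 1x1 layers, so they are skipped when summing the output channels.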
Putting all of the above together, the GoogLeNet convolutional neural network built with the Keras functional API looks like this:
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, BatchNormalization, Concatenate
from tensorflow.keras.layers import Dense, Flatten, Dropout, AveragePooling2D, Softmax
from tensorflow.keras.models import Model


def inception(x, filters):
    """
    Inception module (dimension-reduced form).
    :param x: input tensor
    :param filters: list of six ints, the number of filters in the order
                    [1x1, 3x3 reduce, 3x3, 5x5 reduce, 5x5, pool proj]
    :return: output tensor
    """
    # Branch 1 (leftmost in the diagram): 1x1 convolution
    branch1 = Conv2D(filters[0], (1, 1), strides=1, padding='same', activation='relu')(x)
    # Branch 2: 1x1 reduction followed by a 3x3 convolution
    branch2 = Conv2D(filters[1], (1, 1), strides=1, padding='same', activation='relu')(x)
    branch2 = Conv2D(filters[2], (3, 3), strides=1, padding='same', activation='relu')(branch2)
    # Branch 3: 1x1 reduction followed by a 5x5 convolution
    branch3 = Conv2D(filters[3], (1, 1), strides=1, padding='same', activation='relu')(x)
    branch3 = Conv2D(filters[4], (5, 5), strides=1, padding='same', activation='relu')(branch3)
    # Branch 4: 3x3 max pooling followed by a 1x1 projection
    branch4 = MaxPooling2D((3, 3), strides=1, padding='same')(x)
    branch4 = Conv2D(filters[5], (1, 1), strides=1, padding='same', activation='relu')(branch4)
    # Concatenate the branches along the last (channel) axis,
    # the "Filter concatenation" step in the diagram
    y = Concatenate()([branch1, branch2, branch3, branch4])
    return y


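def y_branch_loss(x):
    """
    Auxiliary classifier.
    :param x: input tensor
    :return: output tensor of shape (batch, 1000)
    """
    # 5x5 average pooling with stride 3; padding defaults to 'valid'
    y_branch = AveragePooling2D((5, 5), strides=3)(x)
    # 1x1 convolution with 128 filters
    y_branch = Conv2D(128, (1, 1), strides=1, padding='same', activation='relu')(y_branch)
    # Flatten so that the Dense layers see one feature vector per image
    # instead of being applied at every spatial position
    y_branch = Flatten()(y_branch)
    y_branch = Dense(1024, activation='relu')(y_branch)
    # Linear 1000-way layer; the Softmax layer below turns it into probabilities
    y_branch = Dense(1000)(y_branch)
    y_branch = Softmax()(y_branch)
    return y_branch

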
input_images = Input(shape=(224, 224, 3))

# Stem: 7x7/2 convolution, max pooling, then 1x1 and 3x3 convolutions
x = Conv2D(64, (7, 7), strides=2, padding='same', activation='relu')(input_images)
x = MaxPooling2D((3, 3), strides=2, padding='same')(x)
x = BatchNormalization()(x)  # the paper's diagram uses local response normalization here
x = Conv2D(64, (1, 1), strides=1, padding='same', activation='relu')(x)
x = Conv2D(192, (3, 3), strides=1, padding='same', activation='relu')(x)
x = BatchNormalization()(x)
x = MaxPooling2D((3, 3), strides=2, padding='same')(x)

# Inception (3a) and (3b)
x = inception(x, [64, 96, 128, 16, 32, 32])
x = inception(x, [128, 128, 192, 32, 96, 64])
x = MaxPooling2D((3, 3), strides=2, padding='same')(x)

# Inception (4a) to (4e); the outputs of (4a) and (4d) feed the auxiliary classifiers
inception_4a = inception(x, [192, 96, 208, 16, 48, 64])
x = inception(inception_4a, [160, 112, 224, 24, 64, 64])
x = inception(x, [128, 128, 256, 24, 64, 64])
inception_4d = inception(x, [112, 144, 288, 32, 64, 64])
x = inception(inception_4d, [256, 160, 320, 32, 128, 128])
x = MaxPooling2D((3, 3), strides=2, padding='same')(x)

# Inception (5a) and (5b)
x = inception(x, [256, 160, 320, 32, 128, 128])
x = inception(x, [384, 192, 384, 48, 128, 128])

# Classifier head: 7x7 average pooling, dropout, 1000-way linear layer, softmax
x = AveragePooling2D((7, 7), strides=1)(x)
x = Dropout(0.4)(x)
x = Flatten()(x)
# Dense produces the logits; applying softmax here as well would apply it twice
x = Dense(1000)(x)
y_main = Softmax()(x)

# Auxiliary classifiers attached to the outputs of (4a) and (4d)
y_branch1 = y_branch_loss(inception_4a)
y_branch2 = y_branch_loss(inception_4d)

model = Model([input_images], [y_main, y_branch1, y_branch2])
model.summary()  # print the network structure
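Printing the network structure gives the following (compare it yourself against the output shapes in the parameter table above). Note that the auxiliary branches in this printout end with shape (None, 4, 4, 1000): it comes from a run in which the Dense layers of the auxiliary classifier were applied directly to the 4x4 feature map. With a Flatten layer in front of those Dense layers, each branch ends with shape (None, 1000), and its first Dense layer has 2048 x 1024 + 1024 = 2,098,176 parameters instead of 132,096.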
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_1 (InputLayer) [(None, 224, 224, 3) 0
__________________________________________________________________________________________________
conv2d (Conv2D) (None, 112, 112, 64) 9472 input_1[0][0]
__________________________________________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 56, 56, 64) 0 conv2d[0][0]
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 56, 56, 64) 256 max_pooling2d[0][0]
__________________________________________________________________________________________________
conv2d_1 (Conv2D) (None, 56, 56, 64) 4160 batch_normalization[0][0]
__________________________________________________________________________________________________
conv2d_2 (Conv2D) (None, 56, 56, 192) 110784 conv2d_1[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 56, 56, 192) 768 conv2d_2[0][0]
__________________________________________________________________________________________________
max_pooling2d_1 (MaxPooling2D) (None, 28, 28, 192) 0 batch_normalization_1[0][0]
__________________________________________________________________________________________________
conv2d_4 (Conv2D) (None, 28, 28, 96) 18528 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_6 (Conv2D) (None, 28, 28, 16) 3088 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
max_pooling2d_2 (MaxPooling2D) (None, 28, 28, 192) 0 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_3 (Conv2D) (None, 28, 28, 64) 12352 max_pooling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_5 (Conv2D) (None, 28, 28, 128) 110720 conv2d_4[0][0]
__________________________________________________________________________________________________
conv2d_7 (Conv2D) (None, 28, 28, 32) 12832 conv2d_6[0][0]
__________________________________________________________________________________________________
conv2d_8 (Conv2D) (None, 28, 28, 32) 6176 max_pooling2d_2[0][0]
__________________________________________________________________________________________________
concatenate (Concatenate) (None, 28, 28, 256) 0 conv2d_3[0][0]
conv2d_5[0][0]
conv2d_7[0][0]
conv2d_8[0][0]
__________________________________________________________________________________________________
conv2d_10 (Conv2D) (None, 28, 28, 128) 32896 concatenate[0][0]
__________________________________________________________________________________________________
conv2d_12 (Conv2D) (None, 28, 28, 32) 8224 concatenate[0][0]
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D) (None, 28, 28, 256) 0 concatenate[0][0]
__________________________________________________________________________________________________
conv2d_9 (Conv2D) (None, 28, 28, 128) 32896 concatenate[0][0]
__________________________________________________________________________________________________
conv2d_11 (Conv2D) (None, 28, 28, 192) 221376 conv2d_10[0][0]
__________________________________________________________________________________________________
conv2d_13 (Conv2D) (None, 28, 28, 96) 76896 conv2d_12[0][0]
__________________________________________________________________________________________________
conv2d_14 (Conv2D) (None, 28, 28, 64) 16448 max_pooling2d_3[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 28, 28, 480) 0 conv2d_9[0][0]
conv2d_11[0][0]
conv2d_13[0][0]
conv2d_14[0][0]
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D) (None, 14, 14, 480) 0 concatenate_1[0][0]
__________________________________________________________________________________________________
conv2d_16 (Conv2D) (None, 14, 14, 96) 46176 max_pooling2d_4[0][0]
__________________________________________________________________________________________________
conv2d_18 (Conv2D) (None, 14, 14, 16) 7696 max_pooling2d_4[0][0]
__________________________________________________________________________________________________
max_pooling2d_5 (MaxPooling2D) (None, 14, 14, 480) 0 max_pooling2d_4[0][0]
__________________________________________________________________________________________________
conv2d_15 (Conv2D) (None, 14, 14, 192) 92352 max_pooling2d_4[0][0]
__________________________________________________________________________________________________
conv2d_17 (Conv2D) (None, 14, 14, 208) 179920 conv2d_16[0][0]
__________________________________________________________________________________________________
conv2d_19 (Conv2D) (None, 14, 14, 48) 19248 conv2d_18[0][0]
__________________________________________________________________________________________________
conv2d_20 (Conv2D) (None, 14, 14, 64) 30784 max_pooling2d_5[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 14, 14, 512) 0 conv2d_15[0][0]
conv2d_17[0][0]
conv2d_19[0][0]
conv2d_20[0][0]
__________________________________________________________________________________________________
conv2d_22 (Conv2D) (None, 14, 14, 112) 57456 concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_24 (Conv2D) (None, 14, 14, 24) 12312 concatenate_2[0][0]
__________________________________________________________________________________________________
max_pooling2d_6 (MaxPooling2D) (None, 14, 14, 512) 0 concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_21 (Conv2D) (None, 14, 14, 160) 82080 concatenate_2[0][0]
__________________________________________________________________________________________________
conv2d_23 (Conv2D) (None, 14, 14, 224) 226016 conv2d_22[0][0]
__________________________________________________________________________________________________
conv2d_25 (Conv2D) (None, 14, 14, 64) 38464 conv2d_24[0][0]
__________________________________________________________________________________________________
conv2d_26 (Conv2D) (None, 14, 14, 64) 32832 max_pooling2d_6[0][0]
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 14, 14, 512) 0 conv2d_21[0][0]
conv2d_23[0][0]
conv2d_25[0][0]
conv2d_26[0][0]
__________________________________________________________________________________________________
conv2d_28 (Conv2D) (None, 14, 14, 128) 65664 concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_30 (Conv2D) (None, 14, 14, 24) 12312 concatenate_3[0][0]
__________________________________________________________________________________________________
max_pooling2d_7 (MaxPooling2D) (None, 14, 14, 512) 0 concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_27 (Conv2D) (None, 14, 14, 128) 65664 concatenate_3[0][0]
__________________________________________________________________________________________________
conv2d_29 (Conv2D) (None, 14, 14, 256) 295168 conv2d_28[0][0]
__________________________________________________________________________________________________
conv2d_31 (Conv2D) (None, 14, 14, 64) 38464 conv2d_30[0][0]
__________________________________________________________________________________________________
conv2d_32 (Conv2D) (None, 14, 14, 64) 32832 max_pooling2d_7[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 14, 14, 512) 0 conv2d_27[0][0]
conv2d_29[0][0]
conv2d_31[0][0]
conv2d_32[0][0]
__________________________________________________________________________________________________
conv2d_34 (Conv2D) (None, 14, 14, 144) 73872 concatenate_4[0][0]
__________________________________________________________________________________________________
conv2d_36 (Conv2D) (None, 14, 14, 32) 16416 concatenate_4[0][0]
__________________________________________________________________________________________________
max_pooling2d_8 (MaxPooling2D) (None, 14, 14, 512) 0 concatenate_4[0][0]
__________________________________________________________________________________________________
conv2d_33 (Conv2D) (None, 14, 14, 112) 57456 concatenate_4[0][0]
__________________________________________________________________________________________________
conv2d_35 (Conv2D) (None, 14, 14, 288) 373536 conv2d_34[0][0]
__________________________________________________________________________________________________
conv2d_37 (Conv2D) (None, 14, 14, 64) 51264 conv2d_36[0][0]
__________________________________________________________________________________________________
conv2d_38 (Conv2D) (None, 14, 14, 64) 32832 max_pooling2d_8[0][0]
__________________________________________________________________________________________________
concatenate_5 (Concatenate) (None, 14, 14, 528) 0 conv2d_33[0][0]
conv2d_35[0][0]
conv2d_37[0][0]
conv2d_38[0][0]
__________________________________________________________________________________________________
conv2d_40 (Conv2D) (None, 14, 14, 160) 84640 concatenate_5[0][0]
__________________________________________________________________________________________________
conv2d_42 (Conv2D) (None, 14, 14, 32) 16928 concatenate_5[0][0]
__________________________________________________________________________________________________
max_pooling2d_9 (MaxPooling2D) (None, 14, 14, 528) 0 concatenate_5[0][0]
__________________________________________________________________________________________________
conv2d_39 (Conv2D) (None, 14, 14, 256) 135424 concatenate_5[0][0]
__________________________________________________________________________________________________
conv2d_41 (Conv2D) (None, 14, 14, 320) 461120 conv2d_40[0][0]
__________________________________________________________________________________________________
conv2d_43 (Conv2D) (None, 14, 14, 128) 102528 conv2d_42[0][0]
__________________________________________________________________________________________________
conv2d_44 (Conv2D) (None, 14, 14, 128) 67712 max_pooling2d_9[0][0]
__________________________________________________________________________________________________
concatenate_6 (Concatenate) (None, 14, 14, 832) 0 conv2d_39[0][0]
conv2d_41[0][0]
conv2d_43[0][0]
conv2d_44[0][0]
__________________________________________________________________________________________________
max_pooling2d_10 (MaxPooling2D) (None, 7, 7, 832) 0 concatenate_6[0][0]
__________________________________________________________________________________________________
conv2d_46 (Conv2D) (None, 7, 7, 160) 133280 max_pooling2d_10[0][0]
__________________________________________________________________________________________________
conv2d_48 (Conv2D) (None, 7, 7, 32) 26656 max_pooling2d_10[0][0]
__________________________________________________________________________________________________
max_pooling2d_11 (MaxPooling2D) (None, 7, 7, 832) 0 max_pooling2d_10[0][0]
__________________________________________________________________________________________________
conv2d_45 (Conv2D) (None, 7, 7, 256) 213248 max_pooling2d_10[0][0]
__________________________________________________________________________________________________
conv2d_47 (Conv2D) (None, 7, 7, 320) 461120 conv2d_46[0][0]
__________________________________________________________________________________________________
conv2d_49 (Conv2D) (None, 7, 7, 128) 102528 conv2d_48[0][0]
__________________________________________________________________________________________________
conv2d_50 (Conv2D) (None, 7, 7, 128) 106624 max_pooling2d_11[0][0]
__________________________________________________________________________________________________
concatenate_7 (Concatenate) (None, 7, 7, 832) 0 conv2d_45[0][0]
conv2d_47[0][0]
conv2d_49[0][0]
conv2d_50[0][0]
__________________________________________________________________________________________________
conv2d_52 (Conv2D) (None, 7, 7, 192) 159936 concatenate_7[0][0]
__________________________________________________________________________________________________
conv2d_54 (Conv2D) (None, 7, 7, 48) 39984 concatenate_7[0][0]
__________________________________________________________________________________________________
max_pooling2d_12 (MaxPooling2D) (None, 7, 7, 832) 0 concatenate_7[0][0]
__________________________________________________________________________________________________
conv2d_51 (Conv2D) (None, 7, 7, 384) 319872 concatenate_7[0][0]
__________________________________________________________________________________________________
conv2d_53 (Conv2D) (None, 7, 7, 384) 663936 conv2d_52[0][0]
__________________________________________________________________________________________________
conv2d_55 (Conv2D) (None, 7, 7, 128) 153728 conv2d_54[0][0]
__________________________________________________________________________________________________
conv2d_56 (Conv2D) (None, 7, 7, 128) 106624 max_pooling2d_12[0][0]
__________________________________________________________________________________________________
concatenate_8 (Concatenate) (None, 7, 7, 1024) 0 conv2d_51[0][0]
conv2d_53[0][0]
conv2d_55[0][0]
conv2d_56[0][0]
__________________________________________________________________________________________________
average_pooling2d (AveragePooli (None, 1, 1, 1024) 0 concatenate_8[0][0]
__________________________________________________________________________________________________
average_pooling2d_1 (AveragePoo (None, 4, 4, 512) 0 concatenate_2[0][0]
__________________________________________________________________________________________________
average_pooling2d_2 (AveragePoo (None, 4, 4, 528) 0 concatenate_5[0][0]
__________________________________________________________________________________________________
dropout (Dropout) (None, 1, 1, 1024) 0 average_pooling2d[0][0]
__________________________________________________________________________________________________
conv2d_57 (Conv2D) (None, 4, 4, 128) 65664 average_pooling2d_1[0][0]
__________________________________________________________________________________________________
conv2d_58 (Conv2D) (None, 4, 4, 128) 67712 average_pooling2d_2[0][0]
__________________________________________________________________________________________________
flatten (Flatten) (None, 1024) 0 dropout[0][0]
__________________________________________________________________________________________________
dense_1 (Dense) (None, 4, 4, 1024) 132096 conv2d_57[0][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 4, 4, 1024) 132096 conv2d_58[0][0]
__________________________________________________________________________________________________
dense (Dense) (None, 1000) 1025000 flatten[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 4, 4, 1000) 1025000 dense_1[0][0]
__________________________________________________________________________________________________
dense_4 (Dense) (None, 4, 4, 1000) 1025000 dense_3[0][0]
__________________________________________________________________________________________________
softmax (Softmax) (None, 1000) 0 dense[0][0]
__________________________________________________________________________________________________
softmax_1 (Softmax) (None, 4, 4, 1000) 0 dense_2[0][0]
__________________________________________________________________________________________________
softmax_2 (Softmax) (None, 4, 4, 1000) 0 dense_4[0][0]
==================================================================================================
Total params: 9,447,144
Trainable params: 9,446,632
Non-trainable params: 512
__________________________________________________________________________________________________
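The spatial sizes in the printout can be cross-checked by hand. This is a minimal sketch that assumes only the rule for padding='same' (output length is the input length divided by the stride, rounded up) and follows the five stride-2 layers of the main path; `same_out` is a helper name made up for this example.

```python
import math


def same_out(n, stride):
    """Output length along one axis for padding='same'."""
    return math.ceil(n / stride)


size = 224
trace = [size]
# Stride-2 layers on the main path: the 7x7 convolution and four max pooling layers.
# The Inception modules use stride 1 everywhere, so they leave the size unchanged.
for stride in (2, 2, 2, 2, 2):
    size = same_out(size, stride)
    trace.append(size)

print(trace)  # [224, 112, 56, 28, 14, 7]
```

The final 7x7 map is then reduced to 1x1 by the 7x7 average pooling layer, which matches the (None, 1, 1, 1024) row in the summary.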
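That completes the GoogLeNet convolutional neural network. I hope it helps. The outline of this tutorial series is attached at the end of the post; follow my updates and let's keep learning together.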