DL with python(14)——tensorflow实现CNN的“八股”

本文涉及到的是中国大学慕课《人工智能实践:Tensorflow笔记》第五讲第10节的内容,对tensorflow环境下卷积神经网络的搭建步骤进行介绍,其代码也是后续几种典型卷积神经网络介绍的基础。

卷积神经网络“八股”

DL with python(4)——基于Keras的二层神经网络鸢尾花分类中,已经介绍了的神经网络搭建的“六步法”,

在此基础上,可以总结出在 Tensorflow 框架下,利用 Keras 来搭建神经网络的“八股”套路,在主干的基础上,还可以添加其他内容,来完善神经网络的功能,如利用自己的图片和标签文件来自制数据集;通过旋转、缩放、平移等操作对数据集进行数据增强;保存模型文件进行断点续训;提取训练后得到的模型参数以及准确率曲线,实现可视化等。

构建神经网络的“八股”套路:

  1. import 引入 tensorflow 及 keras、numpy 等所需模块。
  2. 读取数据集,课程中所利用的 MNIST、cifar10 等数据集比较基础,可以直接从sklearn等模块中引入,但是在实际应用中,大多需要从图片和标签文件中读取所需的数据集。
  3. 搭建所需的网络结构,当网络结构比较简单时,可以利用 keras 模块中的tf.keras.Sequential 来搭建顺序网络模型;但是当网络不再是简单的顺序结构,而是有其它特殊结构出现时(例如 ResNet 中的跳连结构),便需要利用 class来定义自己的网络结构。前者使用起来更加方便,但实际应用中往往需要利用后者来搭建网络。
  4. 对搭建好的网络进行编译(compile),通常在这一步指定所采用的优化器(如 Adam、sgd、RMSdrop 等)以及损失函数(如交叉熵函数、均方差函数等),选择哪种优化器和损失函数往往对训练的速度和效果有很大的影响,至于具体如何进行选择,需要根据具体情况而定(其实我还没有研究)。
  5. 将数据输入编译好的网络来进行训练(model.fit),在这一步中指定训练轮数 epochs 以及 batch_size等信息,由于神经网络的参数量和计算量一般都比较大,训练所需的时间也会比较长,尤其是在硬件条件受限的情况下,所以在这一步中通常会加入断点续训以及模型参数保存等功能,使训练更加方便,同时防止程序意外停止导致数据丢失的情况发生。
  6. 将神经网络模型的具体信息打印出来(model.summary),包括网络结构、网络各层的参数等,便于对网络进行浏览和检查。

通过class类搭建CNN结构的方法

这里搭建的CNN结构如图所示,含有1个卷积过程(卷积、批归一化、激活、池化、Dropout),2个全连接层用于输出结果。
在这里插入图片描述
代码的实现十分简单,这也正体现了tensorflow方便快捷的优点,通过class构建网络是六步法中的第三步。具体的注释在代码中,大部分结构和前面的类似。

## 第一步,导入相关模块
import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
from tensorflow.keras import Model

np.set_printoptions(threshold=np.inf)
## 第二步,导入数据集
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
## 第三步,搭建网络结构
# class类搭建卷积神经网络
class Baseline(Model):  # 类的名字是Baseline,继承了tensorflow的Model类
    def __init__(self): # 定义网络结构
        super(Baseline, self).__init__() # 固定格式,注意类名
        self.c1 = Conv2D(filters=6, kernel_size=(5, 5), padding='same')  # 卷积层,6个卷积核,大小为5*5,全零填充
        self.b1 = BatchNormalization()  # BN层,进行批归一化
        self.a1 = Activation('relu')  # 激活层
        self.p1 = MaxPool2D(pool_size=(2, 2), strides=2, padding='same')  # 池化层,2*2卷积核,步长2,全零填充
        self.d1 = Dropout(0.2)  # dropout层,舍弃20%的输入

        self.flatten = Flatten()   #拉直
        self.f1 = Dense(128, activation='relu')
        self.d2 = Dropout(0.2)
        self.f2 = Dense(10, activation='softmax')
    # class类的调用?
    def call(self, x): # 调用前面定义的各层,实现x到y的前向传播
        x = self.c1(x)
        x = self.b1(x)
        x = self.a1(x)
        x = self.p1(x)
        x = self.d1(x)

        x = self.flatten(x)
        x = self.f1(x)
        x = self.d2(x)
        y = self.f2(x)
        return y

model = Baseline()  # 实例化model
## 第四步,配置训练方法,选择优化器、损失函数、评测指标
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])
# 导入保存的模型,第二次运行才可以进行的操作,实现断点续训
checkpoint_save_path = "./checkpoint/Baseline.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)
# 保存模型,第一次运行执行这一步操作
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)
## 第五步,执行训练,依次为训练集样本,训练集标签,小批量大小32,训练轮次5,测试集,训练集循环1轮次进行一次测试,进行断点续训
history = model.fit(x_train, y_train, batch_size=32, epochs=20, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
## 第六步,打印网络结构和参数统计
model.summary()
# 将模型参数保存到txt文档中
file = open('./weights.txt', 'w')
for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

###############################################    show   ###############################################
# 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']            # 训练集准确率
val_acc = history.history['val_sparse_categorical_accuracy']    # 测试集准确率
loss = history.history['loss']                                  # 训练集损失值
val_loss = history.history['val_loss']                          # 测试集损失值

plt.subplot(1, 2, 1)                            # 将2个图画在一个界面中,第1个图
plt.plot(acc, label='Training Accuracy')        # 绘制训练集准确率
plt.plot(val_acc, label='Validation Accuracy')  # 绘制测试集准确率
plt.title('Training and Validation Accuracy')   # 设置标题
plt.legend()                                    # 画出图例

plt.subplot(1, 2, 2)                            # 将2个图画在一个界面中,第2个图
plt.plot(loss, label='Training Loss')           # 绘制训练集准损失值
plt.plot(val_loss, label='Validation Loss')     # 绘制测试集损失值
plt.title('Training and Validation Loss')       # 设置标题
plt.legend()                                    # 画出图例
plt.show()                                      # 输出图片

运行结果

绘制的相关曲线如图
在这里插入图片描述
还会在代码文件所属的文件夹中保存一个“weights.txt”文件,其中是本次训练的所有参数值,用于下一次的断点续训,即在本次运行结果的基础上继续训练,进一步优化网络参数,提高性能。
下面是第二次运行后的网络性能曲线,可以看到准确率有了很大的提升。
在这里插入图片描述

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值