全连神经网络的经典实践——网络设计

接着前面的知识,本次主要介绍使用tensorflow实现一个完整的神经网络来解决MNIST手写数字识别任务。首先我们要先导入MNIST数据集(即上一节所讲内容,通过read_data_sets()导入数据集即可),之后定义网络中的相关参数,定义参数的代码如下:接着前面的知识,本次主要介绍使用tensorflow实现一个完整的神经网络来解决MNIST手写数字识别任务。

首先我们要先导入MNIST数据集(即上一节所讲内容,通过read_data_sets()导入数据集即可),之后定义网络中的相关参数,定义参数的代码如下:
在这里插入图片描述

其中learning_rate_decay用于滑动平均模型的参数,后面会讲解。相关参数设置完后,就可以初始化网络的权重参数和前向传播过程了,我们这次设计的网络是包含1个隐藏层的全连接结构,并在隐藏层输出之前通过ReLU激活函数进行非线性化。网络的前向传播过程定义在hidden_layer()函数中,x在运行会话时会feed图片数据, y_在运行会话时 feed 答案数据(label)。初始化权重和设计前向传播过程的代码如下:
在这里插入图片描述

为了在采用随机梯度下降算法训练神经网络时提高最终模型在测试数据上的表现, tensorflow 提供了一种在变量上使用滑动平均的方法,通常称之为滑动平均模型。实现滑动平均首先通过train.ExponentialMovingAverage()函数初始化一个滑动平均类,同时需要向函数提供一个衰减率(Decay)参数,这个衰减率将用于控制模型更新的速度。该模型会对每一个变量维护一个影子变量(shadow_variable),影子变量的初始值等于该变量的初始值,同时影子变量的值随着该变量的值的变化而变化,具体变化如下公式:

在这里插入图片描述
其中下角标为n的为更新前的值,n+1为更新后的值,shadow_variable是variable的影子变量,decay是前面提到的衰减率。从公式可以看到, decay 决定了滑动平均模型的更新速度,一般会设成非常接近1的数(如0.99或0.999 ) , decay值越大模型越趋于稳定。ExponentialMovingAverage()接受指定num_updates参数来限制decay的大小,如果在初始化时提供了num_updates参数,那么每次使用的衰减率decay值将由下式决定:
在这里插入图片描述

在得到初始化的滑动平均类之后,可以通过这个类的函数apply()提供要进行滑动平均计算的变量。这个函数的原型为apply(self,var_ list),其中var_list参数是一个传递进来的参数列表。ExponentialMovingAverage类中的average()负责执行影子变量的计算,参数为需要进行计算的原参数。接下来应用滑动平均模型获得各个变量的影子变量,补充代码如下:

在这里插入图片描述
在获得预测结果后,接下来我们需要设计损失函数以计算误差,这里我们使用的前面提到的交叉熵损失,想必大家也知道交叉熵损失经常与softmax分类层搭配使用,tensorflow提供了两者统一的封装函数——softmax_cross_entropy_with_logits(),我们这里使用一个“更严格”的版本——sparse softmax cross entropy with_ logits(),这个函数适合输入样本只能被划分为某一类的情况,这特别适用于 MNIST 手写字图片数据集的分类。sparse_softmax_cross_entropy_with_logits()的用法与前面提到的softmax_cross_entropy_with_logits()相同。在得到交叉熵损失后,我们可以计算权重参数的 L2 正则,并将正则损失和交叉熵损失糅合在一起计算总损失。接着我们还要考虑优化器(Optimizer),这里我们采用随机梯度下降法(SGD),学习率采用指数衰减的形式。在优化器类的minimize()函数指明了最小化的目标,这里设置为总损失。计算损失和建立优化器的代码如下:

在这里插入图片描述
为了打印模型预测出的准确率,还需要定义一些这方面的相关操作。Equal()函数能够通过True或False值返回比较之后的结果。由于一个batch中包含多个样本,所以在得到这个比较的结果后,我们还需要对其进行数值类型转化以及在所有样本上求平均操作。代码如下:

在这里插入图片描述
以上代码执行完后,我们就已经搭建完神经网络了,接下来要做的就是创建回话开始正式的训练过程。首先我们需要初始化所有定义的变量,然后通过for循环进行迭代训练,在循环中,我们使用上一节提到的read_data_sets()函数获取MNIST数据集,并调用该其next_batch()函数获取一个batch数据,并feed给x和y_。循环中还要做的就是执行模型训练结果在验证数据集上的验证,我们设定这个操作发生在循环的轮数是 1000 的倍数时。验证不需要每一轮都取一个batch,所以在 for 循环外定义了validate_feed用于验证过程中将数据feed到网络模型中。Test_feed是我们的测试数据,用于测试最终的精确度。会话部分的代码如下:
在这里插入图片描述

最终的输出结果如下:

在这里插入图片描述
从打印的信息可以看出,在训练刚开始的时候,模型在验证集的准确率只有12%左右,但是当训练迭代至 1000 轮时,模型在验证集的准确率猛增至98%之后,在漫长的迭代过程中,模型在验证集的准确率始终保持在98%左右的水平,我们可以称此时模型遇到了“瓶颈”。最终准确率停在了98.5%左右,除非不改进模型,不然已经没有继续训练下去的必要了,不过现在的准确率已经是一个非常高的水平了。

发布了31 篇原创文章 · 获赞 2 · 访问量 1477
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 数字20 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览