基于TensorFlow的FNN模型——MNIST手写数字识别器(四)之训练和测试【开源】

注意:这是一个完整的项目,建议您按照完整的博客顺序阅读。

文章目录

前言

一、训练神经网络的方法

1、训练主流程

2、定期保存网络模型

3、在TensorBoard中可视化

4、在TensorBoard中记录各个节点的训练信息(可选)

二、评估与测试模型

1、加载模型

(1)重定义计算图结构

(2)滑动平均值重命名变量

(3)读取保存的网络模型

2、随机抽取图片检测预测效果

提示


前言

之前的几篇博客我们已经定义好了如何进行定义计算图模型,接下来我们就要进行计算图模型的训练和评测。


一、训练神经网络的方法

1、训练主流程

首先搭建训练计算图模型的主流程,使用之前的_define_graph()定义好该计算图

接着在一个会话中初始化所有的变量:tf.global_variables_initializer().run()  # 初始化所有变量

最后就可以写一个训练循环来进行训练了,这里最关键的是:sess.run( [self.optimizer],feed_dict={self.x_input: xs,self.y_input: ys})

执行优化器操作才能让网络进行反向传播来优化自身的参数

为了可以快速查看监控模型,我们可以每隔10次输出一次损失值和此时的循环数,方法如下:

    def train(self,train):
        """ 训练一个计算图模型"""
        self._define_graph()
        # 开始训练
        with tf.Session() as sess:
            tf.global_variables_initializer().run()  # 初始化所有变量
            for i in range(1, self.training_step + 1):
                xs, ys = train.next_batch(self.batch_size) # 获取一个批次
                if i % 10 == 0:
                    _, loss_value, step= sess.run([self.optimizer, self.loss, self.global_step],feed_dict={self.x_input: xs, self.y_input: ys})
                    print('Epoich: %d , loss: %g.' % (step, loss_value))
                else:
                    _= sess.run( [self.optimizer],feed_dict={self.x_input: xs,self.y_input: ys})# 优化参数
            train_writer.close()

2、定期保存网络模型

为了可以将我们的网络模型保存起来,我们可以定期将网络模型保存下:

关于保存网络模型的知识,读者可以参考网络资料或者我的相关博客TensorFlow模型的保存与加载(一)——checkpoint模式【源码】

    def train(self,train):
        """ 训练一个计算图模型"""
        self._define_graph()
        saver = tf.train.Saver()# 网络模型保存器
        # 开始训练
        with tf.Session() as sess:
            tf.global_variables_initializer().run()  # 初始化所有变量
            for i in range(1, self.training_step + 1):
                xs, ys = train.next_batch(self.batch_size) # 获取一个批次
                # 定期保存网络
                if i % 1000 == 0:
                    saver.save(sess, os.path.join(self.model_save_path, self.model_name), global_step=self.global_step)  # 保存cnpk模型
                    _, loss_value, step = sess.run([self.optimizer, self.loss, self.global_step], feed_dict={self.x_input: xs, self.y_input: ys})
                    print('Epoich: %d , loss: %g. and save model successfully' % (step, loss_value))
                # 定期打印信息和记录变量
                elif i % 10 == 0:
                    # 直接执行优化器、损失值和step和合并操作
                    _, loss_value, step= sess.run([self.optimizer, self.loss, self.global_step],feed_dict={self.x_input: xs, self.y_input: ys})
                    print('Epoich: %d , loss: %g.' % (step, loss_value))
                    train_writer.add_summary(summary, i) # 添加到graph event文件中用于TensorBoard的显示
                else:
                    _, step = sess.run( [self.optimizer, self.global_step],feed_dict={self.x_input: xs, self.y_input: ys})# 优化参数
            train_writer.close()

3、在TensorBoard中可视化

在前面定义计算图时,我们大量使用了命名空间和监控变量的操作,都是为了在TensorBoard中可视化。

在TensorBoard中进行可视化要先使用:

merged_summary_op = tf.summary.merge_all() # 合并所有的summary为一个操作节点,方便运行

这样我们只需要执行这一个操作节点。更多可视化知识可以参考我的博客TensorBoard训练可视化(一)

在添加了可视化部分代码后,我们的训练代码如下:

 def train(self,train):
        """ 训练一个计算图模型"""
        self._define_graph()
        merged_summary_op = tf.summary.merge_all() # 合并所有的summary为一个操作节点,方便运行
        saver = tf.train.Saver()# 网络模型保存器
        # 开始训练
        with tf.Session() as sess:
            tf.global_variables_initializer().run()  # 初始化所有变量
            train_writer = tf.summary.FileWriter(self.logs_save_path, sess.graph) # 文件输出对象,用于生成graph event文件
            for i in range(1, self.training_step + 1):
                xs, ys = train.next_batch(self.batch_size) # 获取一个批次
                # 定期保存网络
                if i % 1000 == 0:
                    saver.save(sess, os.path.join(self.model_save_path, self.model_name), global_step=self.global_step)  # 保存cnpk模型
                    _, loss_value, step = sess.run([self.optimizer, self.loss, self.global_step], feed_dict={self.x_input: xs, self.y_input: ys})
                    train_writer.add_run_metadata(run_metadata, 'step%03d' % i) # #将节点在运行时的信息写入日志文件
                    print('Epoich: %d , loss: %g. and save model successfully' % (step, loss_value))
                # 定期打印信息和记录变量
                elif i % 10 == 0:
                    # 直接执行优化器、损失值和step和合并操作
                    _, loss_value, step, summary = sess.run([self.optimizer, self.loss, self.global_step, merged_summary_op],feed_dict={self.x_input: xs, self.y_input: ys})
                    print('Epoich: %d , loss: %g.' % (step, loss_value))
                    train_writer.add_summary(summary, i) # 添加到graph event文件中用于TensorBoard的显示
                else:
                    _ = sess.run( [self.optimizer],feed_dict={self.x_input: xs, self.y_input: ys})# 优化参数
            train_writer.close()

4、在TensorBoard中记录各个节点的训练信息(可选)

在利用tensorflow写程序时,我们常常会碰到GPU利用率始终不高的情况,这时我们需要详细了解程序结点的消耗时间。

tensorboard提供记录各部分op的时间消耗,可以帮助开发者了解程序的瓶颈:

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)#配置运行时需要记录的信息
run_metadata = tf.RunMetadata()#运行时记录运行信息的proto
sess.run(res, options=options, run_metadata=run_metadata)

二、评估与测试模型

1、加载模型

(1)重定义计算图结构

在评测模型时,我们一般假设我们是使用保存的网络模型,而不是内存中正在训练的模型,这样更符合真实的业务场景。

在加载模型之前,我们要先定义计算图的关键节点(也可以通过变量名来获取),如下我们重建了计算图的输入和输出节点:

x_input = tf.placeholder(tf.float32, [None, self.n_input], name='x-input')
y_input = tf.placeholder(tf.float32, [None, self.n_output], name='y-input')
output = self._define_net(x_input , regularizer__function=None, is_historgram=False)

注意,这里我们重定义的输入输出结构要尽可能与原来定义的变量名及其name一致。

(2)滑动平均值重命名变量

前面我们介绍了我们使用了一种滑动平均值技术来维护权重参数,这种技术能够使我们的模型更加健壮。

滑动平均值本质上使用一个小范围平均值来代替实际的权重参数,这样可以缓解权重参数过于敏感。

在训练时我们使用真实的权重参数进行训练,同时借助真实的权重参数计算出一个较稳定的影子权重。

在测试时我们就要使用这个影子权重来代替真实的权重参数,使得我们的模型在未知数据集上更稳定,这就需要我们加载模型时对变量进行重命名。

如下,我们需要定义一个Saver对象:

saver = tf.train.Saver({“v1”:v2}) # 将保存模型中name=v1的变量加载到变量v2中,v2的name可以自定义

tf可以直接帮助我们重命名为滑动平均值的方法,如下:

# 滑动平均变量
variables_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)  # 定义一个滑动
variables_to_restore = variables_averages.variables_to_restore()  # 生成变量重命名的列表
# 创建加载变量重命名后的保存器
saver = tf.train.Saver(variables_to_restore)

(3)读取保存的网络模型

    def _run_saved_model(self,images,labels):
        # 加载模型
        with tf.Graph().as_default() as g:
            x_input = tf.placeholder(tf.float32, [None, self.n_input], name='x-input')
            y_input = tf.placeholder(tf.float32, [None, self.n_output], name='y-input')
            output = self._define_net(x_input, regularizer__function=None, is_historgram=False)
            # 滑动平均变量
            variables_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)  # 定义一个滑动平均类
            variables_to_restore = variables_averages.variables_to_restore()  # 生成变量重命名的列表
            # 创建加载变量重命名后的保存器
            saver = tf.train.Saver(variables_to_restore)
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(self.model_save_path)  # 获取ckpt的模型文件的路径
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess, ckpt.model_checkpoint_path)  # 恢复模型参数
                    pred = sess.run(output, feed_dict= {x_input: images, y_input: labels})  # 运行计算图,获取准确率
                    return pred

                else:
                    print('No checkpoint file found')
                    return None

2、随机抽取图片检测预测效果

    def test_random(self,_images,_labels):
        # 随机挑选9个照片
        random_indices = random.sample(range(len(_images)), min(len(_images), 9))
        images, labels = zip(*[(_images[i], _labels[i]) for i in random_indices])
        # 加载模型
        pred = self._run_saved_model(images,labels)
        if pred is not None:
            datahelpter.plot_images(images=images, cls_true=np.argmax(labels, 1), cls_pred=np.argmax(pred, 1),img_size=28, num_channels=1)

 


提示

如果本项目对您的学习有帮助,欢迎点赞支持!

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

魔法攻城狮MRL

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值