Tensorflow复习笔记5:高级保存与恢复的Supervisor模块

简介:Supervisor模块相比于Saver模块高级一些,封装了一些操作使得操作更简便。

关于tensorflow基础模块的文章很多了,详细介绍supervisor模块的文章基本没有。我自己探索了一些,放上自己的笔记,能帮到忙的话可以点个赞~


参考入门文章:
https://blog.csdn.net/u012436149/article/details/53341372
给出了简单的完整流程,便于入门理解

https://www.jianshu.com/p/7490ebfa3de8
tensorflow官网出的Supervisor介绍 的中文翻译版:长期训练好帮手

https://www.tensorflow.org/versions/r1.1/programmers_guide/supervisor
tensorflow官网出的Supervisor介绍

https://www.tensorflow.org/api_docs/python/tf/train/Supervisor
官方的Supervisor接口文档。不过缺乏完整的例子。


Supervisor的优点
  • 它建议使用一个logdir目录 来同时保存 模型图 和 权重参数。
  • 不用自己创建Saver 和 summary writer。summary部分还是建议手动操作。
  • 具体是通过global_step参数来自动管理训练进度。

流程demo:

global_step = tf.Variable(0, name='global_step', trainable=False)
# ...
merged_summary_op = tf.summary.merge_all()

# 定义Supervisor
sv = tf.train.Supervisor(logdir=logs_path, init_op=tf.global_variables_initializer(), global_step=global_step, summary_op=None)
with sv.managed_session() as sess :
    # 超参数
    ITERATION = 10000 +1
    BATCH_SIZE = 64
    ITERATION_SHOW = 500

    while not sv.should_stop():
        step = sess.run(global_step)
        if step > ITERATION : break

        sess.run(train_step...)

        # 隔一段时间 检查一下准确率
        if step%ITERATION_SHOW == 0:
            merged_summary, accuracy = sess.run([merged_summary_op, accuracy_op]...)
            # write to summary
            sv.summary_computed(sess, merged_summary)

一些经验

关于global_step

supervisor模块中挺多地方都依赖于global_step变量。比如保存检查点、记录已训练次数。建议还是使用上global_step。
1. 定义global_step

global_step = tf.Variable(0, name='global_step', trainable=False)

这一句建议写在开始搭模型之前,要确保没有with包着它。也就是说它的name要是global_step:0
2. 在定义的train_step里 加上global_step参数

train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy_op, global_step=global_step)

不加上这句的话,supervisor会找不到global_step变量,然后保存的模型全部是0.meta
3. 在定义Supervisor里 加上global_step参数

sv = tf.train.Supervisor(logdir = logs_path,
                        init_op = tf.global_variables_initializer(), 
                        save_model_secs = 10, 
                        global_step = global_step, 
                        summary_op = None)

不加上这句的话,运行时会报如下警告。不过,运行结果都没啥问题…

WARNING:tensorflow:Error encountered when serializing global_step.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'Tensor' object has no attribute 'to_proto'

加上global_step之后,最直观的就是保存检查点的时候会加上后缀:model-1234。这也说明supervisor可以正常使用global_step了。

关于在supervisor模块中如何让 训练集和验证集的图像同框

图像同框的方法在上一篇文章中说明了,需要创建两个summaryWriter,分别write to summary就行了。不过sueprvisor模块中它会在内部创建summarywriter,那么如何同框呢?
经过分析supervisor的summary_computed函数源码后,发现其只是对saver模块的add_summary函数进行了简单封装。所以我找到了一个不很正规的解决方法。那就是手动实现了summary_computed的操作:为验证集另外创建一个summarywriter, 再add_summary就行了~ 大致是这样的:

# 定义验证集的FileWriter
validation_writer = tf.summary.FileWriter(os.path.join(logs_path,'validation'))
# 定义Supervisor
# ...
    # 计算当前训练集的准确率
    merged_summary, accuracy, loss = sess.run([merged_summary_op,     accuracy_op, cross_entropy_op], feed_dict={x: batch[0], y: batch[1],     keep_prob: 1.0})
    sv.summary_computed(sess, merged_summary, step)

    # 计算验证集的准确率
    (accuracy_validation_sum,
     loss_validation_sum,
     accuracy_validation, loss_validation) = sess.run([accuracy_scalar,    loss_scalar, accuracy_op, cross_entropy_op],    feed_dict={x:mnist.validation.images, y: mnist.validation.labels,    keep_prob: 1.0})
    # write to summary
    validation_writer.add_summary(accuracy_validation_sum, step)
    validation_writer.add_summary(loss_validation_sum, step)
全部的完整代码在github上

踩坑阶段

  1. 报错: (绕开的方式解决了)
    尝试在mnist的CNN模型上 加入supervisor模块替代Saver模块时,出现报错:
tensorflow.python.framework.errors_impl.InvalidArgumentError: You must feed a value for placeholder tensor 'input_y' with dtype float and shape [?,10]

触发条件:
我的Tensorflow版本 : 1.8.0
代码中使用了Supervisor ,并且定义了tf.summary节点,并且该节点依赖于某个placeholder节点。

本来以为是哪个feed_dict没写全,检查了一遍,feed_dict都没问题。
而且诡异的是这个报错是在session整个执行完了之后才出现的。也就是说其实没有影响到执行的结果。
报错的demo代码:mnist的softmax实现。

import tensorflow as tf

import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

##### 构建图结构
# 定义输入:x和y
x = tf.placeholder(tf.float32, [None, 784], name='input_x')
y_ = tf.placeholder(tf.float32, [None, 10], name='input_y')

# 定义权重参数
W = tf.Variable(tf.truncated_normal([784, 10], stddev=0.1), name='weights')
b = tf.Variable(tf.constant(0.1, shape=[10]), name='bias')

# 定义模型
y_output = tf.nn.softmax(tf.matmul(x, W) + b)

# 定义交叉熵
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_output))
# 监控交叉熵
tf.summary.scalar('loss', cross_entropy)
# tf.summary.scalar('loss', cross_entropy, collections=['loss'])
# 定义优化器和训练器
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# 定义准确率的计算方式
# 取预测值和真实值 概率最大的标签
correct_prediction = tf.equal(tf.argmax(y_output,1), tf.argmax(y_,1))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


##### 构建会话
# 定义log保存路径
logs_path = 'logs/'
# 定义summary node集合
merged_summary_op = tf.summary.merge_all()

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
# 定义Supervisor
sv = tf.train.Supervisor(logdir=logs_path, init_op=tf.global_variables_initializer())
with sv.managed_session(config=config) as sess :
    # 超参数
    ITERATION = 1000 +1
    BATCH_SIZE = 64
    ITERATION_SHOW = 100

    for step in range(ITERATION) :
        # 执行训练op
        batch = mnist.train.next_batch(BATCH_SIZE)
        sess.run(train_step, feed_dict={x: batch[0], y_: batch[1]})

        if step%ITERATION_SHOW == 0:
            # 计算当前训练样本的准确率
            merged_summary, accuracy = sess.run([merged_summary_op, accuracy_op], feed_dict={x: batch[0], y_: batch[1]})
            sv.summary_computed(sess, merged_summary, global_step=step)

            # 输出当前准确率
            print("step %d, accuarcy:%.4g" % (step, accuracy))

            # 保存模型
            sv.saver.save(sess, logs_path, global_step=step)
pass

原因:
貌似是因为Supervisor自带的summary服务和我们用tf.summary.scalar冲突了。
我尝试过把session里面的run全部注释掉,然而还是报这个错。我也没法了。。

其中一个解决方法:
很简单,把Supervisor自带的summary服务关掉即可:
定义Supervisor时,设置summary_op=None。

sv = tf.train.Supervisor(logdir=logs_path, init_op=tf.global_variables_initializer(), summary_op=None)

相对应的坑随即出现了:如何使用Supervisor自带的summary服务呢?还是因为我本机的配置有问题?
有空再看看吧,暂时不折腾了。。

  • 1
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值