【tensorflow 大马哈鱼】高级保存与恢复的Supervisor模块

参考的教程

https://blog.csdn.net/YiRanZhiLiPoSui/article/details/81143166


 

参考入门文章: 
https://blog.csdn.net/u012436149/article/details/53341372 
给出了简单的完整流程,便于入门理解

https://www.jianshu.com/p/7490ebfa3de8 
tensorflow官网出的Supervisor介绍 的中文翻译版:长期训练好帮手

https://www.tensorflow.org/versions/r1.1/programmers_guide/supervisor 
tensorflow官网出的Supervisor介绍

https://www.tensorflow.org/api_docs/python/tf/train/Supervisor 
官方的Supervisor接口文档。不过缺乏完整的例子。

 


一、不使用Supervisor的情况

在不使用Supervisor的时候,我们的代码经常是这么组织的

variables
...
ops
...
summary_op
...
merge_op = tf.summary.merge_all() 
saver
init_op

with tf.Session() as sess:
  writer = tf.summary.FileWriter()
  sess.run(init)
  saver.restore()
  for ...:
    train
    merged_summary = sess.run(merge_op)
    writer.add_summary(merged_summary,i)
  saver.save

二、使用Supervisor的情况

使用一个logdir目录 来同时保存 模型图 和 权重参数

sv = tf.train.Supervisor(logdir=logs_path,init_op=init_op,summary_op=None) #logdir用来保存checkpoint和summary

注意有个参数是summary_op

如果没有summary_op=None,则使用Supervisor自带的summary服务

使用sv = tf.train.Supervisor()  会自动初始化。

无参数也可以,最好加上logdir,同时,两个logdir可以不同

import tensorflow as tf
tf.reset_default_graph()
a = tf.Variable(1)
b = tf.Variable(2)
c = tf.add(a,b)
update = tf.assign(a,c)
logs_path='./logaa'

'''不需要初始化'''
#init_op = tf.global_variables_initializer()
#sv = tf.train.Supervisor(logdir=logs_path,init_op=init_op) #logdir用来保存checkpoint和summary
'''这样也可以,最好加上logdir'''
sv = tf.train.Supervisor(logdir=logs_path)         #这样也可以

with sv.managed_session() as sess: #会自动去logdir中去找checkpoint,如果没有的话,自动执行初始化
    for i in range(71):
        update_ = sess.run(update)
        print(update_)
#        if i % 10 == 0:
#            merged_summary = sess.run(merged_summary_op)
#            sv.summary_computed(sess, merged_summary)
        if i%10 == 0:
            sv.saver.save(sess,logs_path+'/model',global_step=i)

如果有summary_op=None,则需自建summary服务

import tensorflow as tf
tf.reset_default_graph()
a = tf.Variable(1)
b = tf.Variable(2)
c = tf.add(a,b)
update = tf.assign(a,c)
logs_path='./logaa/'

tf.summary.scalar('a', a) 
init_op = tf.global_variables_initializer()
merged_summary_op = tf.summary.merge_all()  
sv = tf.train.Supervisor(logdir=logs_path,init_op=init_op,summary_op=None) #logdir用来保存checkpoint和summary

with sv.managed_session() as sess: #会自动去logdir中去找checkpoint,如果没有的话,自动执行初始化
    for i in range(1000):
        update_ = sess.run(update)
        print(update_)
        if i % 10 == 0:
            merged_summary = sess.run(merged_summary_op)
            sv.summary_computed(sess, merged_summary)
        if i%100 == 0:
            sv.saver.save(sess,logs_path,global_step=i)

一个完整的例子

# -*- coding: utf-8 -*-

import tensorflow as tf
tf.reset_default_graph()
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

##### 构建图结构
# 定义输入:x和y
x = tf.placeholder(tf.float32, [None, 784], name='input_x')
y_ = tf.placeholder(tf.float32, [None, 10], name='input_y')

# 定义权重参数
W = tf.Variable(tf.truncated_normal([784, 10], stddev=0.1), name='weights')
b = tf.Variable(tf.constant(0.1, shape=[10]), name='bias')

# 定义模型
y_output = tf.nn.softmax(tf.matmul(x, W) + b)

# 定义交叉熵
cross_entropy = -tf.reduce_sum(y_ * tf.log(y_output))
# 监控交叉熵
tf.summary.scalar('loss', cross_entropy)
# tf.summary.scalar('loss', cross_entropy, collections=['loss'])
# 定义优化器和训练器
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

# 定义准确率的计算方式
# 取预测值和真实值 概率最大的标签
correct_prediction = tf.equal(tf.argmax(y_output,1), tf.argmax(y_,1))
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))


##### 构建会话
# 定义log保存路径
logs_path = 'logsbbb/'
# 定义summary node集合
merged_summary_op = tf.summary.merge_all()

config = tf.ConfigProto()
config.gpu_options.allow_growth = True
# 定义Supervisor
sv = tf.train.Supervisor(logdir=logs_path, init_op=tf.global_variables_initializer(), summary_op=None)
with sv.managed_session(config=config) as sess :
    # 超参数
    ITERATION = 1000 +1
    BATCH_SIZE = 64
    ITERATION_SHOW = 100

    for step in range(ITERATION) :
        # 执行训练op
        batch = mnist.train.next_batch(BATCH_SIZE)
        sess.run(train_step, feed_dict={x: batch[0], y_: batch[1]})

        if step%ITERATION_SHOW == 0:
            # 计算当前训练样本的准确率
            merged_summary, accuracy = sess.run([merged_summary_op, accuracy_op], feed_dict={x: batch[0], y_: batch[1]})
            sv.summary_computed(sess, merged_summary, global_step=step)

            # 输出当前准确率
            print("step %d, accuarcy:%.4g" % (step, accuracy))

            # 保存模型
            sv.saver.save(sess, logs_path, global_step=step)

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值