tensorflow的模型保存和载入以及可视化

1.tensorflow的模型保存和读取

Tensorflow模型主要包含网络的设计(图)和训练好的各参数的值等。实际生成的Tensorflow模型有四个主要的文件:
1.cheakpoint文件,一个二进制的文件,仅用于保存最新的cheakpoint的记录。
2…data结尾的文件,包含了weights, biases, gradients和其他variables的值。
3…index结尾的文件,包含了weights, biases, gradients和其他variables的值。
4…meta结尾的文件,保存tensorflow完整的graph、variables、operation、collection。

在这里插入图片描述

1.普通方法

保存模型

首先需要建立一个save,然后在session中通过saver的save即将模型保存起来
saver=tf.train.Saver()
saver.save(sess,“save_path/file_name”)#file_name不存在会自动创建

import tensorflow as tf  
  
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")  
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")  
result = v1 + v2  
saver = tf.train.Saver()  
with tf.Session() as sess:  
    sess.run(tf.global_variables_initializer())  
    saver.save(sess, "Model/model.ckpt")  

载入模型

载入模型也需要定义tensorflow计算图上的所有运算,并声明一个tf.train.Saver类,不过加载的时候不需要进行变量的初始化声明。
saver=tf.train.Saver()
saver.restore(sess,“save_path/file_name”)

import tensorflow as tf  
  
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")  
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")  
result = v1 + v2  
saver = tf.train.Saver()  
with tf.Session() as sess:  
    saver.restore(sess, "./Model/model.ckpt") # 注意此处路径前添加"./"  
    print(sess.run(result)) # [ 3.] 

2.检查点(checkpoint)

保存模型的时候并不限于在训练之后,在训练的过程中也需要保存,因为Tensorflow在训练的过程中难免出现中断的情况,我们要使得我们训练的参数保留下来,否则下次还要重头开始训练。
这种在训练过程中保存模型,称之为保存检查点。
当中加入了saver = tf.train.Saver(max_to_keep=1)#表示表中最多只有一个检查点文件,在训练的过程中,新生成的模型会覆盖以前的模型。
在这里插入图片描述
saver=tf.train.Saver(),文中设置的是每隔两步生成模型。
在这里插入图片描述

添加并保存检查点
tf.reset_default_graph()
# 创建模型
# 占位符
X = tf.placeholder("float")
Y = tf.placeholder("float")
# 模型参数
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")
# 前向结构
z = tf.multiply(X, W)+ b
saver=tf.train.Saver()
savedir="log/"
load_epoch=18
with tf.Session() as sess2:
    sess2.run(tf.global_variables_initializer())
    saver.restore(sess2, savedir+"linermodel.cpkt-" + str(load_epoch))
    print ("x=0.2,z=", sess2.run(z, feed_dict={X: 0.2}))
   #生成的结果是x=0.2,z= [0.44503736]

3更简单的保存检查点

tf.train.MonitoredTrainingSession函数,是另一种保存检查点功能代码的方法,可以直接实现保存和载入检查点模型的文件,与前面不同的是,前面所介绍的是按训练步数保存的,这里是按照训练时间保存的。通过指定save_checkpoint_secs 参数额具体秒数,来设置多久保存一次检查点。
(1)如果不设置save_checkpoint_secs的参数,默认的保存时间间隔是10分钟。
(2)使用本方法必须定义global_step变量,否则会报错误。

import tensorflow as tf
tf.reset_default_graph()#重置图
global_step = tf.train.get_or_create_global_step()#初始化训练步数
step = tf.assign_add(global_step, 1)#每次加1
#设置检查点的路径,checkpoint_dir='log/checkpoints'
with tf.train.MonitoredTrainingSession(checkpoint_dir='log/checkpoints',save_checkpoint_secs  = 2) as sess:
    print(sess.run([global_step]))
    while not sess.should_stop():
        i = sess.run( step)
        print( i)

在这里插入图片描述
当再次运行的时候,直接从模型的保存结点开始,
在这里插入图片描述

2.Tensorflow可视化–Tensorboard

ensorboard是TensorFlow自带的一个强大的可视化工具,也是一个web应用程序套件。通过将tensorflow程序输出的日志文件的信息可视化使得tensorflow程序的理解、调试和优化更加简单高效。支持其七种可视化:

SCALARS:展示训练过程中的准确率、损失值、权重/偏置的变化情况
IMAGES:展示训练过程中及记录的图像
AUDIO:展示训练过程中记录的音频
GRAPHS:展示模型的数据流图,以及各个设备上消耗的内存和时间
DISTRIBUTIONS:展示训练过程中记录的数据的分布图
HISTOGRAMS:展示训练过程中记录的数据的柱状图
EMBEDDINGS:展示词向量后的投影分布
参考 https://blog.csdn.net/john_bh/article/details/80366596

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值