Tensorflow基础---保存模型,恢复模型

tensorflow从已经训练好的模型中,恢复(指定)权重(构建新变量、网络)并继续训练(finetuning)

tensorflow模型的存储和恢复有两种模式:ckpt模式和PB模式

  1. ckpt 模式:
  • 计算图和变量分开保存
  • 读取模型时需要重新定义计算图,无需指明变量名
  1. pb 模式: 封装存储方案,隐藏模型结构
  • 计算图和变量封装在一个文件中
  • 无需重新定义计算图,但是需要指出变量名

一、ckpt模式

ckpt模型文件

训练了一个神经网络之后,我们希望保存它以便将来使用。那么什么是TensorFlow模型?Tensorflow模型主要包含我们所培训的网络参数的网络设计或图形和值。因此,Tensorflow模型有两个主要的文件:

a) Meta graph:
这是一个协议缓冲区,它保存了完整的graph;即所有变量、操作、集合等。该文件以.meta作为扩展名。

b) Checkpoint file:
这是一个二进制文件,它包含了所有的权重、偏差、梯度和其他所有变量的值。这个文件有一个扩展名.ckpt。然而,Tensorflow从0.11版本中改变了这一点。现在,我们有两个文件,而不是单个.ckpt文件:

 mymodel.data-00000-of-00001
 mymodel.index

c)checkpoint文件
Tensorflow也有一个名为checkpoint的文件,它只有最新保存的checkpoint文件的记录。多个.meta共用一个checkpoint文件

ckpt模型的两种保存方式和两种恢复方式

两种方式保存模型,
1.保存所有tensor,即整张图的所有变量,
2.只保存指定scope的变量

两种方式恢复模型,
1.导入模型的graph,用该graph的saver来restore变量
2.在新的代码段中写好同样的模型(变量名称及scope的name要对应),用默认的graph的saver来restore指定scope的变量

基于saver的两种保存方式:

1.保存整张图,所有变量
...
init = tf.global_variables_initializer()
saver = tf.train.Saver()
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
with tf.Session(config=config) as sess:
  sess.run(init)
  ...
  writer.add_graph(sess.graph)
  ...
  saved_path = saver.save(sess,saver_path)
  ...
2.保存图中的部分变量
...
init = tf.global_variables_initializer()

#####################################################################################
#获取指定scope的tensor
vgg_ref_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_feat_fc')]
#初始化saver时,传入一个var_list的参数
saver = tf.train.Saver(vgg_ref_vars)
#####################################################################################

config = tf.ConfigProto()
config.gpu_options.allow_growth=True
with tf.Session(config=config) as sess:
  sess.run(init)
  ...
  writer.add_graph(sess.graph)
  ...
  ###############################################################
  saved_path = saver.save(sess,saver_path)
  ##################################################################
  ...

基于saver的两种恢复方式:

1.导入graph来恢复
...
vgg_meta_path = params['vgg_meta_path'] # 文件路径,后缀是'.ckpt.meta'的文件
vgg_graph_weight = params['vgg_graph_weight'] # 文件路径,后缀是'.ckpt'的文件,里面是各个tensor的值

#########################导入graph########################################
# 导入graph到当前的默认graph中,返回导入graph的saver
saver_vgg = tf.train.import_meta_graph(vgg_meta_path) 
#################################################################


x_vgg_feat = tf.get_collection('inputs_vgg')[0] #placeholder, [None, 4096],获取输入的placeholder
feat_decode = tf.get_collection('feat_encode')[0] #[None, 1024],获取要使用的tensor
"""
以上两个获取tensor的方式也可以为:
graph = tf.get_default_graph()
centers = graph.get_tensor_by_name('loss/intra/center_loss/centers:0')
当然,前提是有tensor的名字
"""
...
init = tf.global_variables_initializer()
saver = tf.train.Saver() # 这个是当前新图的saver
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
with tf.Session(config=config) as sess:
  sess.run(init)
  ...
  ########################导入权重#####################################################
  saver_vgg.restore(sess, vgg_graph_weight)#使用导入图的saver来恢复
  #########################################################################
  ...
2.重写一样的graph,然后恢复指定scope的变量
######################用代码重建graph##########################################
def re_build():#重建保存的那个graph
  with tf.variable_scope('vgg_feat_fc'): #没错,这个scope要和需要恢复模型中的scope对应
  ...
  return ...
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
...
##################指定要恢复的变量scope###############################################
vgg_ref_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_feat_fc') 
...
init = tf.global_variables_initializer()
saver_vgg = tf.train.Saver(vgg_ref_vars) # 这个是要恢复部分的saver
saver = tf.train.Saver() # 这个是当前新图的saver
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
with tf.Session(config=config) as sess:
  sess.run(init)
  ...
  #############用指定路径的权重文件:vgg_graph_weight来恢复##############################
  saver_vgg.restore(sess, vgg_graph_weight)#使用导入图的saver来恢复
  ...

总结一下,这里的要点就是,在restore的时候,saver要和模型对应,如果直接用当前graph的saver = tf.train.Saver(),来恢复保存模型的权重saver.restore(vgg_graph_weight),就会报错,提示key/tensor … not found之类的错误;
写graph的时候,一定要注意写好scope和tensor的name,合理插入variable_scope;
最方便的方式还是,用第1种方式来保存模型,这样就不用重写代码段了,然后第1种方式恢复,不过为了稳妥,最好还是通过获取var_list,指定saver的var_list,妥妥的!

最新发现,用第1种方式恢复时,要记得当前的graph和保存的模型中没有重名的tensor,否则当前graph的tensor name可能不是那个name,可能在后面加了"1"…--||

二、PB模式

PB 文件定义:

MetaGraph 的 protocol buffer 格式的文件,包括计算图,数据流,以及相关的变量等

PB 文件优点:

  • 具有语言独立性,可独立运行,任何语言都可以解析
  • 允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow的模型
  • 保存为 PB 文件时候,模型的变量都会变成常量,使得模型的大小减小
  • 可以把多个计算图保存到一个 PB 文件中
  • 支持计算图的功能和使用设备命名区分多个计算图,例如 serving or training, CPU or GPU。

PB模式存储模型代码示例

import tensorflow as tf
from tensorflow.python.framework import graph_util
x = tf.Variable(tf.random_uniform([3]))
y = tf.Variable(tf.random_uniform([3]))
z = tf.add(x, y, name='op_to_store')
with tf.Session() as sess:
	sess.run(tf.global_variables_initializer())
	print(sess.run(x))
	print(sess.run(y))
	print(sess.run(z))
	constant_graph = graph_util.	convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
	with tf.gfile.FastGFile('./tmp2/pbplus.pb', mode='	wb') as f:
		f.write(constant_graph.SerializeToString())

PB模式恢复模型代码示例

import tensorflow as tf
from tensorflow.python.platform import gfile
# ...... something disappeared ......
with tf.Session() as sess:
	with gfile.FastGFile('./tmp2/pbplus.pb', 'rb') as f 	:
		graph_def = tf.GraphDef()
		graph_def.ParseFromString(f.read())
		sess.graph.as_default()
		tf.import_graph_def(graph_def, name='')
	sess.run(tf.global_variables_initializer())
		z = sess.graph.get_tensor_by_name('op_to_store
		:0') # x? y?
		print(sess.run(z))

基于SavedModel的存储与恢复

Tensorflow: 保存和复原模型(save and restore)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值