tensorflow从已经训练好的模型中,恢复(指定)权重(构建新变量、网络)并继续训练(finetuning)
tensorflow模型的存储和恢复有两种模式:ckpt模式和PB模式
- ckpt 模式:
- 计算图和变量分开保存
- 读取模型时需要重新定义计算图,无需指明变量名
- pb 模式: 封装存储方案,隐藏模型结构
- 计算图和变量封装在一个文件中
- 无需重新定义计算图,但是需要指出变量名
一、ckpt模式
ckpt模型文件
训练了一个神经网络之后,我们希望保存它以便将来使用。那么什么是TensorFlow模型?Tensorflow模型主要包含我们所培训的网络参数的网络设计或图形和值。因此,Tensorflow模型有两个主要的文件:
a) Meta graph:
这是一个协议缓冲区,它保存了完整的graph;即所有变量、操作、集合等。该文件以.meta作为扩展名。
b) Checkpoint file:
这是一个二进制文件,它包含了所有的权重、偏差、梯度和其他所有变量的值。这个文件有一个扩展名.ckpt。然而,Tensorflow从0.11版本中改变了这一点。现在,我们有两个文件,而不是单个.ckpt文件:
mymodel.data-00000-of-00001
mymodel.index
c)checkpoint文件
Tensorflow也有一个名为checkpoint的文件,它只有最新保存的checkpoint文件的记录。多个.meta共用一个checkpoint文件
ckpt模型的两种保存方式和两种恢复方式
两种方式保存模型,
1.保存所有tensor,即整张图的所有变量,
2.只保存指定scope的变量
两种方式恢复模型,
1.导入模型的graph,用该graph的saver来restore变量
2.在新的代码段中写好同样的模型(变量名称及scope的name要对应),用默认的graph的saver来restore指定scope的变量
基于saver的两种保存方式:
1.保存整张图,所有变量
...
init = tf.global_variables_initializer()
saver = tf.train.Saver()
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
with tf.Session(config=config) as sess:
sess.run(init)
...
writer.add_graph(sess.graph)
...
saved_path = saver.save(sess,saver_path)
...
2.保存图中的部分变量
...
init = tf.global_variables_initializer()
#####################################################################################
#获取指定scope的tensor
vgg_ref_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_feat_fc')]
#初始化saver时,传入一个var_list的参数
saver = tf.train.Saver(vgg_ref_vars)
#####################################################################################
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
with tf.Session(config=config) as sess:
sess.run(init)
...
writer.add_graph(sess.graph)
...
###############################################################
saved_path = saver.save(sess,saver_path)
##################################################################
...
基于saver的两种恢复方式:
1.导入graph来恢复
...
vgg_meta_path = params['vgg_meta_path'] # 文件路径,后缀是'.ckpt.meta'的文件
vgg_graph_weight = params['vgg_graph_weight'] # 文件路径,后缀是'.ckpt'的文件,里面是各个tensor的值
#########################导入graph########################################
# 导入graph到当前的默认graph中,返回导入graph的saver
saver_vgg = tf.train.import_meta_graph(vgg_meta_path)
#################################################################
x_vgg_feat = tf.get_collection('inputs_vgg')[0] #placeholder, [None, 4096],获取输入的placeholder
feat_decode = tf.get_collection('feat_encode')[0] #[None, 1024],获取要使用的tensor
"""
以上两个获取tensor的方式也可以为:
graph = tf.get_default_graph()
centers = graph.get_tensor_by_name('loss/intra/center_loss/centers:0')
当然,前提是有tensor的名字
"""
...
init = tf.global_variables_initializer()
saver = tf.train.Saver() # 这个是当前新图的saver
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
with tf.Session(config=config) as sess:
sess.run(init)
...
########################导入权重#####################################################
saver_vgg.restore(sess, vgg_graph_weight)#使用导入图的saver来恢复
#########################################################################
...
2.重写一样的graph,然后恢复指定scope的变量
######################用代码重建graph##########################################
def re_build():#重建保存的那个graph
with tf.variable_scope('vgg_feat_fc'): #没错,这个scope要和需要恢复模型中的scope对应
...
return ...
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
...
##################指定要恢复的变量scope###############################################
vgg_ref_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='vgg_feat_fc')
...
init = tf.global_variables_initializer()
saver_vgg = tf.train.Saver(vgg_ref_vars) # 这个是要恢复部分的saver
saver = tf.train.Saver() # 这个是当前新图的saver
config = tf.ConfigProto()
config.gpu_options.allow_growth=True
with tf.Session(config=config) as sess:
sess.run(init)
...
#############用指定路径的权重文件:vgg_graph_weight来恢复##############################
saver_vgg.restore(sess, vgg_graph_weight)#使用导入图的saver来恢复
...
总结一下,这里的要点就是,在restore的时候,saver要和模型对应,如果直接用当前graph的saver = tf.train.Saver(),来恢复保存模型的权重saver.restore(vgg_graph_weight),就会报错,提示key/tensor … not found之类的错误;
写graph的时候,一定要注意写好scope和tensor的name,合理插入variable_scope;
最方便的方式还是,用第1种方式来保存模型,这样就不用重写代码段了,然后第1种方式恢复,不过为了稳妥,最好还是通过获取var_list,指定saver的var_list,妥妥的!
最新发现,用第1种方式恢复时,要记得当前的graph和保存的模型中没有重名的tensor,否则当前graph的tensor name可能不是那个name,可能在后面加了"1"…--||
二、PB模式
PB 文件定义:
MetaGraph 的 protocol buffer 格式的文件,包括计算图,数据流,以及相关的变量等
PB 文件优点:
- 具有语言独立性,可独立运行,任何语言都可以解析
- 允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow的模型
- 保存为 PB 文件时候,模型的变量都会变成常量,使得模型的大小减小
- 可以把多个计算图保存到一个 PB 文件中
- 支持计算图的功能和使用设备命名区分多个计算图,例如 serving or training, CPU or GPU。
PB模式存储模型代码示例
import tensorflow as tf
from tensorflow.python.framework import graph_util
x = tf.Variable(tf.random_uniform([3]))
y = tf.Variable(tf.random_uniform([3]))
z = tf.add(x, y, name='op_to_store')
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(x))
print(sess.run(y))
print(sess.run(z))
constant_graph = graph_util. convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
with tf.gfile.FastGFile('./tmp2/pbplus.pb', mode=' wb') as f:
f.write(constant_graph.SerializeToString())
PB模式恢复模型代码示例
import tensorflow as tf
from tensorflow.python.platform import gfile
# ...... something disappeared ......
with tf.Session() as sess:
with gfile.FastGFile('./tmp2/pbplus.pb', 'rb') as f :
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
sess.run(tf.global_variables_initializer())
z = sess.graph.get_tensor_by_name('op_to_store
:0') # x? y?
print(sess.run(z))