1. The tf.train.Saver class
TensorFlow saves and restores models mainly through the tf.train.Saver class, which is defined in tensorflow/python/training/saver.py.
Saver's initializer takes the following arguments:
__init__(self,
    var_list=None,         # a dict or list of the variables to save; None saves all saveable objects
    reshape=False,         # if True, variables may change shape when restored from a checkpoint
    sharded=False,         # whether to shard the checkpoint, one shard per device
    max_to_keep=5,         # number of recent checkpoints to keep; older ones are deleted as new ones are written
    keep_checkpoint_every_n_hours=10000.0,  # additionally keep one checkpoint every N hours
    name=None,
    restore_sequentially=False,  # whether to restore variables sequentially (reduces peak memory on large models)
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)
2. Saving a model and its parameters
The procedure is: define the graph, run initialization and the computation, then save the resulting graph and parameters:
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, "model/model.ckpt")
The optional arguments of save are:
save(
    sess,              # a session whose graph is built and whose variables are initialized
    save_path,         # path prefix for the checkpoint files
    global_step=None,  # if provided, appended to save_path to distinguish checkpoints from different steps
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True,
    write_state=True
)
After running, four files appear in the model directory:
checkpoint lists the saved models; model.ckpt.meta stores the structure of the graph; model.ckpt.index and model.ckpt.data store the variable names and values. (Older versions of TensorFlow stored the data in a single ckpt file; newer versions split it into these two.)
3. Reading and restoring a model
Use the checkpoint file to locate the most recent model:
ckpt = tf.train.get_checkpoint_state('./model/')  # find the latest checkpoint
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')  # model_checkpoint_path: ./model/model.ckpt-4
with tf.Session() as sess:
    saver.restore(sess, ckpt.model_checkpoint_path)
To restore a model: define the graph, restore the parameters, then run the computation. In effect, restore takes the place of initialization:
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "model/model.ckpt")
    print(sess.run(result))
Alternatively, restore the graph itself from the checkpoint and fetch its tensors by name:
saver = tf.train.import_meta_graph("model/model.ckpt.meta")
with tf.Session() as sess:
    saver.restore(sess, "model/model.ckpt")
    result = tf.get_default_graph().get_tensor_by_name("add:0")
    print(sess.run(result))
When restoring, each checkpoint value is assigned to the variable with the same name by default; if a variable in the graph does not exist in the checkpoint, a NotFoundError is raised. (The checkpoint may contain more variables than the graph needs, but never fewer.)
You can also manually map checkpoint values onto differently named variables in the graph:
u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
result = u1 + u2
saver = tf.train.Saver({"v1": u1, "v2": u2})  # "v1" and "v2" are names in the checkpoint; u1 and u2 are the current variables
with tf.Session() as sess:
    saver.restore(sess, "model/model.ckpt")
    print(sess.run(result))
The variable names and values stored in a checkpoint can be read as follows:
from tensorflow.python import pywrap_tensorflow

ckpt = tf.train.get_checkpoint_state('./model')
checkpoint_path = ckpt.model_checkpoint_path
# Read data from the checkpoint file
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
# Print each tensor's name and value
for key in var_to_shape_map:
    print("tensor_name: ", key)
    print("tensor_value: ", reader.get_tensor(key))
Getting an operation from the graph:
sess.graph.get_operations() returns every operation in the graph.
# When defining the operation, set its scope and name, then save
with tf.variable_scope('fc2'):
    logit = tf.add(tf.matmul(fc1, weight_fc2), bias_fc2, name='logit')
# When restoring, first restore the graph, then fetch the operation with sess.graph.get_operation_by_name
ckpt = tf.train.get_checkpoint_state('model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
with tf.Session() as sess:
    saver.restore(sess, ckpt.model_checkpoint_path)
    logit = sess.graph.get_operation_by_name("fc2/logit")
    print(logit)
Getting a tensor from the graph:
# Restore the graph and parameters from the saved model, fetch the input and
# output tensors, then feed test data to get predictions.
# When saving, set the scope and name:
# with tf.variable_scope('placeholder'):
#     image = tf.placeholder(tf.float32, [None, 784], name='image')
#     label = tf.placeholder(tf.int32, [None,], name='label')
# with tf.variable_scope('fc2'):
#     logit = tf.add(tf.matmul(drop, weight_fc2), bias_fc2, name='logit')
ckpt = tf.train.get_checkpoint_state('model/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
with tf.Session() as sess:
    saver.restore(sess, ckpt.model_checkpoint_path)
    graph = sess.graph
    image = graph.get_tensor_by_name('placeholder/image:0')  # input placeholder
    label = graph.get_tensor_by_name('placeholder/label:0')  # input placeholder
    logit = graph.get_tensor_by_name('fc2/logit:0')          # output: logit
    predict = sess.run(logit, feed_dict={image: batch_image, label: batch_label})
Restoring only some of the variables:
variables_to_restore = tf.contrib.framework.get_variables_to_restore(
    exclude=['logit'])
init_assign_op, init_feed_dict = tf.contrib.framework.assign_from_checkpoint(
    ckpt.model_checkpoint_path, variables_to_restore)

def InitAssignFn(scaffold, sess):
    sess.run(init_assign_op, init_feed_dict)

scaffold = tf.train.Scaffold(saver=tf.train.Saver(), init_fn=InitAssignFn)