使用tensorflow进行模型的保存很简单,生成一个saver实例然后saver.save()即可,在恢复时有两个问题很棘手:
- 怎么得到graph
- 怎么得到placeholder以创建feed_dictx
得到graph:
1.如果你有model的定义文件,可以像训练时一样重新创建一个图,然后saver.restore()获得图中的参数权重:
# 设置超参
hp = hparams
hp.is_training = False # 训练时为True,infer时为False否则infer结果不对
# 创建model
with tf.variable_scope('model') as scope:
model = create_model(args.model, hp)
model.build_graph()
model.add_loss()
model.add_decoder()
# 创建saver
saver = tf.train.Saver()
# 执行sess
with tf.Session(graph=tf.get_default_graph()) as sess: # 此时default_graph就是上面create_model过程中创建的图
saver.restore(sess, restore_path)
sess.run(tf.global_variables_initializer())
# 构造feed_dict
sess.run(其他)
2.如果你没有model的定义文件,别担心,由saver.save()获得的chekpoint文件(.index , .data , .meta)中.meta就保存了graph的结构,先使用 import_meta_graph 向saver中导入graph再进行restore:
# 创建saver
saver = tf.train.Saver()
# 运行sess
with tf.Session() as sess:
saver = tf.train.import_meta_graph('logging/logs-ASR_wavnet/model.ckpt-'+str(step)+'.meta')
saver.restore(sess, restore_path)
sess.run(tf.global_variables_initializer())
# 构造feed_dict
sess.run(其他)
构建feed_dict:
1.在构造feed_dict时你必须得知道input中的placehoder的信息(比如name,shape等),如果你有model的定义文件,这件事会很容易:
# 设置超参
hp = hparams
hp.is_training = False # 训练时为True,infer时为False否则infer结果不对
# 创建model
with tf.variable_scope('model') as scope:
model = create_model(args.model, hp)
model.build_graph()
model.add_loss()
model.add_decoder()
# 创建saver
saver = tf.train.Saver()
# 执行sess
with tf.Session(graph=tf.get_default_graph()) as sess: # 此时default_graph就是上面create_model过程中创建的图
saver.restore(sess, restore_path)
sess.run(tf.global_variables_initializer())
# 构造feed_dict
feed_dict = {model.input: [1]}
sess.run(其他)
2.但是如果你没有model的定义,如何找出上面的model.input这个tensor呢?
答案是通过tensoboard查看placeholder的名字,然后通过get_tensor_by_name获取该tensor:
# 创建saver
saver = tf.train.Saver()
# 运行sess
with tf.Session() as sess:
saver = tf.train.import_meta_graph('logging/logs-ASR_wavnet/model.ckpt-'+str(step)+'.meta')
saver.restore(sess, restore_path)
sess.run(tf.global_variables_initializer())
# 构造feed_dict
graph = sess.graph
inputs = graph.get_tensor_by_name('model/NET/NET_Input/mfcc_inputs:0')# 注意使用全名否则会报错:没有这个op
feed_dict = {inputs:[1]}
sess.run(其他)
参考: