[tensorflow] 模型的恢复

使用tensorflow进行模型的保存很简单,生成一个saver实例然后saver.save()即可,在恢复时有两个问题很棘手:

  1. 怎么得到graph
  2. 怎么得到placeholder以创建feed_dictx

得到graph:

1.如果你有model的定义文件,可以像训练时一样重新创建一个图,然后saver.restore()获得图中的参数权重:

# 设置超参
hp = hparams
hp.is_training = False # 训练时为True,infer时为False否则infer结果不对

# 创建model
with tf.variable_scope('model') as scope:
    model = create_model(args.model, hp)
    model.build_graph()
    model.add_loss()
    model.add_decoder()

# 创建saver
saver = tf.train.Saver()

# 执行sess
with tf.Session(graph=tf.get_default_graph()) as sess: # 此时default_graph就是上面create_model过程中创建的图
    saver.restore(sess, restore_path)
    sess.run(tf.global_variables_initializer())
    # 构造feed_dict
    sess.run(其他)

2.如果你没有model的定义文件,别担心,由saver.save()获得的chekpoint文件(.index , .data , .meta)中.meta就保存了graph的结构,先使用 import_meta_graph 向saver中导入graph再进行restore:

# 创建saver
saver = tf.train.Saver()

# 运行sess
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('logging/logs-ASR_wavnet/model.ckpt-'+str(step)+'.meta')
    saver.restore(sess, restore_path)
    sess.run(tf.global_variables_initializer())
    # 构造feed_dict
    sess.run(其他)

构建feed_dict:

1.在构造feed_dict时你必须得知道input中的placehoder的信息(比如name,shape等),如果你有model的定义文件,这件事会很容易:

# 设置超参
hp = hparams
hp.is_training = False # 训练时为True,infer时为False否则infer结果不对

# 创建model
with tf.variable_scope('model') as scope:
    model = create_model(args.model, hp)
    model.build_graph()
    model.add_loss()
    model.add_decoder()

# 创建saver
saver = tf.train.Saver()

# 执行sess
with tf.Session(graph=tf.get_default_graph()) as sess: # 此时default_graph就是上面create_model过程中创建的图
    saver.restore(sess, restore_path)
    sess.run(tf.global_variables_initializer())
    # 构造feed_dict
    feed_dict = {model.input: [1]}
    sess.run(其他)

2.但是如果你没有model的定义,如何找出上面的model.input这个tensor呢?

答案是通过tensoboard查看placeholder的名字,然后通过get_tensor_by_name获取该tensor

 

# 创建saver
saver = tf.train.Saver()

# 运行sess
with tf.Session() as sess:
    saver = tf.train.import_meta_graph('logging/logs-ASR_wavnet/model.ckpt-'+str(step)+'.meta')
    saver.restore(sess, restore_path)
    sess.run(tf.global_variables_initializer())
    # 构造feed_dict
    graph = sess.graph
    inputs = graph.get_tensor_by_name('model/NET/NET_Input/mfcc_inputs:0')# 注意使用全名否则会报错:没有这个op
    feed_dict = {inputs:[1]}
    sess.run(其他)

参考:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值