TensorFlow之保存/恢复模型

模型保存

# save the specific variables
saver = tf.train.Saver(...variables list...)

# this will only save trainable variables
saver = tf.train.Saver(tf.trainable_variables()) 

# this will save all the variables in the graph
saver = tf.train.Saver()

关于上面保存变量选择的问题,我曾经有一个教训,就是在模型中使用了BN层,然后只保存了trainable variables,结果导致BN层学习到的moving mean/moving variance丢失,无法restore完整的模型了,所以如果不是很清楚,还是保存所有变量信息吧。

save_path = saver.save(sess, 'model_name', global_step=global_steps, write_meta_graph=True)

上面的model_without_meta=True表示是否保存.meta文件,也就是整个graph文件,出于保险考虑,觉得还是保存了好,默认也是True.

模型恢复

模型恢复分两种情形:

一种是不知道模型的graph,在此情形下就要借助上边模型保存时提到的.meta文件了,如果没有保存,那就gg了,大致的操作如下:

# 导入graph,并且在graph的基础上初始化一个saver(注意:saver必须依托一个graph来定义,无graph则无saver)
saver = tf.train.import_meta_graph('./my_model.meta')

graph = tf.get_default_graph()
graph_input = graph.get_tensor_by_name("graph_input:0")
graph_target = graph.get_tensor_by_name("graph_target:0")
hidden_state = graph.get_tensor_by_name("hidden_state:0")
is_training = graph.get_tensor_by_name("is_training:0")
hidden = graph.get_tensor_by_name("hidden:0")
output = graph.get_tensor_by_name("ouput:0")
loss = graph.get_tensor_by_name(options['loss']+':0')

# restore
saver.restore(sess, 'my_model')

至于上面出现的诸如graph_input:0之类的各种ops的名字,也要先知道才行,否则。。。只能自求多福了(不过可以通过分析.meta构造的graph来得知)

另一种就是预先已知了graph,这样就好办多了,直接先把graph构造起来,然后在构造好的graph上定义一个saver,然后就可以开心的restore了,比如:

# 这一句是用来构造graph的
graph_input, graph_target, hidden_state, is_training, hidden, output, loss, train_op = model_graph.graphConstruct(LR)

# 在graph上定义saver,restore所有的variables
saver = tf.train.Saver()

# restore
saver.restore(sess, 'my_model')

完!

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值