Tensorflow 如何存取网络模型

    当我训练完网络模型之后,会想到如何去保存训练好的weightsbias等网络参数,并在将来进行分类或者识别的任务中重新载入(restore)这个训练好的网络。那么在tensorflow中是如何实现对网络模型的保存的呢?
    在tensorflow中,变量存储在二进制文件中,主要包含从变量名到tensor值的映射关系。当创建一个Saver对象时,可以选择性地为检查点文件中的变量设置变量名。
    具体的,首先,给变量赋值,不过要在其后加上参数name=“”,注意,这里的name即要保存到网络模型的变量名称,未来在进行网络模型的载入时需要通过该变量值进行数据读取,类似字典的感觉。
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2”)
    之后,创建一个saver对象,来进行保存,同时不要忘记设定保存的路径。
saver = tf.train.Saver()
save_path = saver.save(sess, "./MNISTmodel/model.ckpt")
print ("Model saved in file: ", save_path)
    模型保存好之后,在需要再次使用这个模型时,同样需要再创建一个saver对象。不要忘记,要将模型中之前保存好的变量名称再赋给需要载入的模型,即
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name=“v2”)
不过此时不需要对这些变量进行初始化了
saver = tf.train.Saver()
......
with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "./MNISTmodel/model.ckpt")
  print "Model restored."
    这样就可以直接恢复之前训练好的模型了。经过我的验证,准确度与之前训练好的时刻准确度一致。证明网络模型确实被成功恢复了。
    模型的保存不仅为了将来再次使用它进行分类等任务,也可以用来做fine-tuning。

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值