TensorFlow入门(六、模型的保存和载入)

本文介绍了如何在TensorFlow中使用saver类保存和加载模型,包括基本的save方法、指定变量映射的高级用法以及load的restore方法。
摘要由CSDN通过智能技术生成

保存模型

使用TensorFlow的saver()类先实例化一个saver对象,然后在session中通过saver的save方法将模型保存起来。代码示例如下:

#初始化所有变量
init = tf.global_variable_initializer()

#定义saver和保存路径
saver = tf.train.Saver()
saverdir = "save_path"

#启动Session
with tf.Session() as sess:
    sess.run(init)
    #使用saver的save方法保存
    saver.save(sess,saverdir + "file_name")

        其中,filename如果不存在,程序会自动创建。

打印模型中的内容

使用inspect_checkpoint包中的print_tensors_in_checkpoint_file方法将模型中的具体内容打印出来。代码示例如下:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
form tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

saverdir = "log/"
print_tensors_in_checkpoint_file(savedir + "linearmodel.cpkt",None,True)

保存模型的其他方法

使用saver()类保存模型时,可以在函数中放入参数来实现更高级的功能,如指定存储变量名字与变量的对应关系。代码示例如下:

W = tf.Variable(1.0,name = "weight")
b = tf.Variable(2.0,name = "bias")

saver = tf.train.Saver({'weight':W,'bias':b})
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver.save(sess,savedir + "linearmodel.cpkt")
print_tensors_in_checkpoint_file(savedir + "linearmodel.cpkt",None,True)

载入模型

通过调用saver的restore()函数,从指定的路径找到模型文件,并覆盖到相关参数中。代码示例如下:

#初始化所有变量
init = tf.global_variable_initializer()

#定义saver和保存路径
saver = tf.train.Saver()
saverdir = "save_path"

#启动Session
with tf.Session() as sess:
    sess.run(init)
    #使用saver的restore方法载入模型
    print("x=0.2,z=",sess.run(z,feed_dict = {X:0.2}))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值