tensorflow模型保存和加载

本文详细介绍了在TensorFlow中如何保存和加载模型,包括模型保存产生的四个文件的作用,以及模型加载的两种方式:完整提取和给定变量名的映射。通过实例展示了如何使用tf.train.Saver进行模型的保存和加载。
摘要由CSDN通过智能技术生成

模型保存:saver = tf.train.Saver() saver.save(sess, ‘./model/model.ckpt’)

tensorflow模型保存后会产生四个文件:
checkpoint: 简介文件
model.ckpt.index 映射关系(保存了辅助索引信息)
model.ckpt.metel 原数据信息(保存了当前图结构)
model.ckpt.data-00000-of-00001具体的数据(保存了当前参数名和值)

在tensorflow中checkpoints文件是一个二进制文件,用于存储所有的weights,biases,gradients和其他variables的值。.meta文件则用于存储 graph中所有的variables, operations, collections等。简言之一个存储参数,一个存储图。

著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
运行后,得到的模型,截屏展示一下

注意:tensorflow模型,如果一个模型保存了好几次,后面的模型会覆盖掉前面的模型

模型加载即模型提取,两种方式:

1.完整提取
2.给定变量名的映射
# 模型保存和加载
v1 = tf.Variable(tf.constant(3.0), name='v1')
v2 = tf.Variable(tf.constant(4.0), name='v2')
result = v1 + v2

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # 模型保存到result文件夹下,文件前缀为:model.ckpt
    saver.save(sess, './model/model.ckpt')
'''

'''
# 模型的提取一(完整提取:需要完整恢复保存之前的数据格式)
v1 = tf.Variable(tf.constant(1.0), name='v1')
v2 = tf.Variable(tf.constant(4.0), name='v2')
result = v1 + v2

saver = tf.train.Saver()
with tf.Session() as sess:
    # 会从对应的文件夹中加载变量、图等相关信息
    saver.restore(sess, './model/model.ckpt')
    print(sess.run([result]))  # 注意v1现在改为1.0,result结果仍然是7.0
'''

'''
# 直接加载图, 不需要定义变量了
saver = tf.train.import_meta_graph('./model/model.ckpt.meta')  # 加载了原数据

with tf.Session() as sess:
    saver.restore(sess, "./model/model.ckpt")

    # 注意:tensorflow1.14.0版本no attribute 'get_tensor_by_name',所以下面现在不能运行

    # print(sess.run(tf.get_default_graph.get_tensor_by_name('add:0')))
    # 此处直接写result会出错,所以改为另一种方式
    # 找到原图,然后依据名称找张量。在原图中,v1叫v1,v2叫v2,result叫add,又因为是第一个加,所以是add:0(这样写是怕如果有多个同样的操作,能够区分开)

    # print(sess.run(tf.get_default_graph.get_tensor_by_name('v1:0')))
'''

'''

# 模型的提取二(加载的时候,给定变量名的映射)
# 开发模型的是一个人,用的又是另一个人,可能有两个模型,都有v变量,你要使用其中一个模型,那就可以用变量映射,以免混淆???不懂。反正现在知道它也是想使用那个模型了,也是一种方式
a = tf.Variable(tf.constant(1.0), name='a')
b = tf.Variable(tf.constant(4.0), name='b')
result = a + b

saver = tf.train.Saver({'v1': a, 'v2': b})  # v1为模型保存的变量名称,a为现在代码的变量名称
with tf.Session() as sess:
    # 会从对应的文件夹中加载变量、图等相关信息
    saver.restore(sess, './model/model.ckpt')
    print(sess.run([result]))  # [7.0]
    print(sess.run([a]))  # [3.0]
'''
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值