模型保存:saver = tf.train.Saver() saver.save(sess, ‘./model/model.ckpt’)
tensorflow模型保存后会产生四个文件:
checkpoint: 简介文件
model.ckpt.index 映射关系(保存了辅助索引信息)
model.ckpt.metel 原数据信息(保存了当前图结构)
model.ckpt.data-00000-of-00001具体的数据(保存了当前参数名和值)
在tensorflow中checkpoints文件是一个二进制文件,用于存储所有的weights,biases,gradients和其他variables的值。.meta文件则用于存储 graph中所有的variables, operations, collections等。简言之一个存储参数,一个存储图。
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。
注意:tensorflow模型,如果一个模型保存了好几次,后面的模型会覆盖掉前面的模型
模型加载即模型提取,两种方式:
1.完整提取
2.给定变量名的映射
# 模型保存和加载
v1 = tf.Variable(tf.constant(3.0), name='v1')
v2 = tf.Variable(tf.constant(4.0), name='v2')
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 模型保存到result文件夹下,文件前缀为:model.ckpt
saver.save(sess, './model/model.ckpt')
'''
'''
# 模型的提取一(完整提取:需要完整恢复保存之前的数据格式)
v1 = tf.Variable(tf.constant(1.0), name='v1')
v2 = tf.Variable(tf.constant(4.0), name='v2')
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:
# 会从对应的文件夹中加载变量、图等相关信息
saver.restore(sess, './model/model.ckpt')
print(sess.run([result])) # 注意v1现在改为1.0,result结果仍然是7.0
'''
'''
# 直接加载图, 不需要定义变量了
saver = tf.train.import_meta_graph('./model/model.ckpt.meta') # 加载了原数据
with tf.Session() as sess:
saver.restore(sess, "./model/model.ckpt")
# 注意:tensorflow1.14.0版本no attribute 'get_tensor_by_name',所以下面现在不能运行
# print(sess.run(tf.get_default_graph.get_tensor_by_name('add:0')))
# 此处直接写result会出错,所以改为另一种方式
# 找到原图,然后依据名称找张量。在原图中,v1叫v1,v2叫v2,result叫add,又因为是第一个加,所以是add:0(这样写是怕如果有多个同样的操作,能够区分开)
# print(sess.run(tf.get_default_graph.get_tensor_by_name('v1:0')))
'''
'''
# 模型的提取二(加载的时候,给定变量名的映射)
# 开发模型的是一个人,用的又是另一个人,可能有两个模型,都有v变量,你要使用其中一个模型,那就可以用变量映射,以免混淆???不懂。反正现在知道它也是想使用那个模型了,也是一种方式
a = tf.Variable(tf.constant(1.0), name='a')
b = tf.Variable(tf.constant(4.0), name='b')
result = a + b
saver = tf.train.Saver({'v1': a, 'v2': b}) # v1为模型保存的变量名称,a为现在代码的变量名称
with tf.Session() as sess:
# 会从对应的文件夹中加载变量、图等相关信息
saver.restore(sess, './model/model.ckpt')
print(sess.run([result])) # [7.0]
print(sess.run([a])) # [3.0]
'''