TensorFlow模型持久化(保存、加载)

1.保存TensorFlow模型:

import tensorflow as tf
saver = tf.train.Saver()

with tf.Session() as sess:
    ...
    saver.save(sess, path)

path是保存模型的路径及文件名,一般是以.ckpt为后缀,保存完会出现三个文件,一个是model.ckpt.meta,保存的是TensorFlow的计算图结构,第二个是model.ckpt,保存了TensorFlow每一个变量的取值,第三个是checkpoint文件,保存了一个目录下所有模型文件列表。

2.加载模型

path是指的保存模型路径下的model.ckpt文件,加载模型不用运行变量初始化过程,而是将变量的值通过已经保存的模型加载进来,并将读取到的变量名称对之前声明的变量名称进行覆盖:

import tensorflow as tf
saver = tf.train.Saver()

with tf.Session() as sess:
    saver.restore(sess, path)

如果不希望重复定义图上运算,也可以直接加载已经持久化的图。可通过张量的名称获取张量:

sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))

保存或者加载部分变量,在声明tf.train.Saver类时可以提供一个列表指定需要保存和加载的变量:

tf.train.Saver([v1])

TensorFlow的tf.train.Saver类也支持在保存或者加载变量时给变量重命名,通过字典将保存时的变量名和需要加载的变量联系起来:

saver = tf.Saver({"v1":v1, "v2":v2})

如果声明的变量名称和已经保存的模型中的变量的名称不同,直接读取会报错,所以要通过重命名来加载模型中的变量:

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")
#此时如果直接使用tf.train.Saver()来加载模型会报变量找不到的错误
saver = tf.train.Saver({"v1":v1, "v2": v2})
#会将原来名为v1的变量加载到v1变量(名字是other-v1)中,v2同理

滑动平均值在模型中的保存:

v = tf.Variable(0, dtype=tf.float32, name="v")

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
#声明滑动平均模型后,会自动生成一个影子变量:v/ExponentialMovingAverage,此时变量有v:0和v/ExponentialMovingAverage
saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.gloabal_variables_initializer()
    sess.run(init_op)
    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    saver.save(sess, path)
    #保存时,TensorFlow会自动将v:0和v/ExponentialMovingAverage都保存下来

重命名滑动平均值的加载:

v = tf.Variable(0, dtype=tf.float32, name="v")
ema = tf.train.ExponentialMovingAverage(0.99)
#tf.train.ExponentialMovingAverage类提供了variables_to_restore函数来生成tf.train.Saver类所需要的重命名字典
saver = tf.train.saver(ema.variables_to_restore())
with tf.Session() as sess:
    saver.restore(sess, path)

将变量和计算图结构分成不同的文件存储时也不方便,可以使用TensorFlow的convert_variables_to_constants函数,将计算图中的变量及其取值通过常量的方式保存,这样TensorFlow整个计算图可以统一放在一个文件中。

from tensorflow.python.framework import graph_util
v1 = tf.Variable(tf.constant(1.0 ,shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    graph_def = tf.get_default_graph().as_graph_def()
    #将图中变量保存为常量,同时将图中不必要的节点去掉(比如变量初始化操作)
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])#add是需要保存节点的名称
    #节点名称为add,张量为add:0
    #将导出的模型存入文件
    with tf.gfile.GFile("model/model2.pb", "wb") as f:
        f.write(output_graph_def.SerializeToString())

当只需要计算图中某个节点的取值时,可使用下面方法,在迁移学习中十分有用:

from tensorflow.python.platform import gfile

with tf.Session() as sess:
    model_file = "model/model2.pb"
    #读取模型,并将文件解析成对应的GraphDef Protocol Buffer
    with gfile.FastGFile(model_file, 'rb') as f:
        graph_def = tf.Graph()
        graph_def.ParseFromString(f.read())
    #在保存的时候是保存的节点名,所以为add,加载时取出的是张量名称,所以是add:0
    result = tf.import_graph_def(graph_def, return_elements=["add:0"])

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值