Tensorflow入门(六)——模型持久化

tensorflow提供了tf.train.saver类来保存还原一个神经网络模型。

1.保存计算图

以下为保存计算图的方法:

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

init_op = tf.global_variables_initializer()
# 声明tf.train.saver类用于保存模型
saver = tf.train.Saver()

with tf.Session() as sess:
     sess.run(init_op)
     saver.save(sess, "./checkpoint_dir/MyModel")

然后会生成三个文件。
第一个为model.ckpt.meta,保存了tensorflow计算图的结构,即网络结构。
第二个为model.ckpt,保存了Tensorflow中每一个变量的取值。
第三个为checkpoint文件,保存了一个目录下所有的模型文件列表。
以下为加载方法:

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

saver = tf.train.Saver()
with tf.Session() as sess:

    saver.restore(sess, "./checkpoint_dir/MyModel")
    print(sess.run(result))

2.变量重命名

在保存和加载时给变量重命名,使用字典:

saver = tf.train.Saver({"v1":v1,"v2":v2})

从而方便使用滑动平均值(将影子变量映射到变量自身,从而不需要再次调用函数计算)

# 保存滑动平均模型,从而不需要再次计算
v = tf.Variable(0, dtype=tf.float32, name="v")

for variables in tf.global_variables():
    print(variables.name)
    #输出v:0
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables():
    print(variables.name)
    #输出v/ExponentialMovingAverage:0

saver = tf.train.Saver()
with tf.Session() as sess:
    ini_op = tf.global_variables_initializer()
    sess.run(ini_op)

    sess.run(tf.assign(v, 10))
    sess.run(maintain_averages_op)
    saver.save(sess, "./checkpoint_dir/MyModel01")
    print(sess.run([v, ema.average(v)]))
    #输出[10.0, 0.099999905]
    

读取变量的滑动平均值:

 v = tf.Variable(0, dtype=tf.float32, name="v")
 ema = tf.train.ExponentialMovingAverage(0.99)
# variables_to_restore生成字典来通过变量重命名直接读取变量的滑动平均值
 saver = tf.train.Saver(ema.variables_to_restore())
 with tf.Session() as sess:

    saver.restore(sess, "./checkpoint_dir/MyModel")
    print(sess.run(v))

variables_to_restore生成字典={“v/ExponentialMovingAverage”: v}

3.将Tensorflow程序放在一个文件中

在测试或离线预测时,只需要知道如何从神经网络输入层计算到输入层,不需要变量初始化,模型保存等辅助信息,而convert_variables_to_constants函数将计算图中的变量及取值通过常量的方式保存到一个文件中,
实例如下:

import tensorflow as tf
from tensorflow.python.framework import graph_util

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    # 导出计算图的GraphDef部分,这一部分即可完成从输入层到输出层的计算
    graph_def = tf.get_default_graph().as_graph_def()
    # 将变量及取值转化为常量,并将不必要的节点去掉。'add'为需要保存的节点名称
    output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
    with tf.gfile.GFile("./simple_model", "wb")as f:
        f.write(output_graph_def.SerializeToString())

以下程序可直接得到定义的加法运算的结果,适用于只需要某个节点的值时,
在迁移学习中将得到应用。

import tensorflow as tf
from tensorflow.python.platform import gfile

with tf.Session() as sess:
    model_filename = "./combined model"
    with gfile.FastGFile(model_filename, 'rb') as f:
        graph_def = tf.GraphDef()  # 读取
        graph_def.ParseFromString(f.read())  # 解析成PB文件
    # ["add:0"]是张量的名称
    result = tf.import_graph_def(graph_def, return_elements=["add:0"])
    print(sess.run(result))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值