TensorFlow学习--TensorFlow模型的存储与恢复

TensorFlow模型的存储与恢复
最简单的保存和恢复模型的方法是使用tf.train.Saver对象.

模型的存储

用tf.train.Saver创建一个Saver来存储模型中的所有变量.

#!/usr/bin/python
# coding:utf-8

import tensorflow as tf
# 定义两个常量Variable
v1 = tf.Variable(tf.constant(1.0, shape=[3]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[5]), name="v2")
# 变量初始化
init_op = tf.initialize_all_variables()

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init_op)
    save_path = saver.save(sess, "model/model.ckpt")
    print "Model saved in file:", save_path

输出:

Model saved in file: model/model.ckpt

可以在model目录下看到:

这里写图片描述
变量存储在二进制文件里,主要包含从变量到tensor值的映射关系.

模型的恢复

用同一个Saver对象来恢复变量.
当从文件中恢复变量时,不需要事先对变量进行初始化.

#!/usr/bin/python
# coding:utf-8

import tensorflow as tf
v1 = tf.Variable(tf.constant(0.0, shape=[3]), name="v1")
v2 = tf.Variable(tf.constant(0.0, shape=[5]), name="v2")

# 当从文件中恢复变量时,不需要事先初始化
# init_op = tf.initialize_all_variables()

saver = tf.train.Saver()
with tf.Session() as sess:
    # sess.run(init_op)
    saver.restore(sess, "model/model.ckpt")
    print "Model:"
    print v1.eval()
    print v2.eval()

输出:

Model:
[ 1.  1.  1.]
[ 2.  2.  2.  2.  2.]

指定变量存储与恢复

如果不给tf.train.Saver()传入任何参数,则saver将处理graph中的所有变量.
通过给tf.train.Saver()传入python字典或列表,来保持变量及其对应的名称:键对应使用的名称,值对应被管理的变量.

传入字典

存储

#!/usr/bin/python
# coding:utf-8

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[3]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[5]), name="v2")

init_op = tf.initialize_all_variables()

# 如果不给tf.train.Saver()传入任何参数,则saver将处理graph中的所有变量
saver = tf.train.Saver({"variable_v1":v1})
with tf.Session() as sess:
    sess.run(init_op)
    save_path = saver.save(sess, "model/model_v1.ckpt")
    print "Model saved in file:", save_path

输出:

Model saved in file: model/model_v1.ckpt

可以在model目录下看到:

这里写图片描述

恢复

#!/usr/bin/python
# coding:utf-8

import tensorflow as tf
v1 = tf.Variable(tf.constant(0.0, shape=[3]), name="v1")
v2 = tf.Variable(tf.constant(0.0, shape=[5]), name="v2")

saver = tf.train.Saver({"variable_v1":v1})
with tf.Session() as sess:
    # sess.run(init_op)
    saver.restore(sess, "model/model_v1.ckpt")
    print "Model v1:"
    print v1.eval()
    # 或使用sess.run(v1)
    # print sess.run(v1)

输出:

Model v1:
[ 1.  1.  1.]
传入列表

存储

import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0, shape=[3]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[5]), name="v2")

init_op = tf.initialize_all_variables()
saver = tf.train.Saver([v1, v2])
with tf.Session() as sess:
    sess.run(init_op)
    saver.save(sess, "model/model_v1v2.ckpt")

恢复

import tensorflow as tf
v1 = tf.Variable(tf.constant(0.0, shape=[3]), name="v1")
v2 = tf.Variable(tf.constant(0.0, shape=[5]), name="v2")

saver = tf.train.Saver([v1])
with tf.Session() as sess:
    saver.restore(sess, "model/model_v1v2.ckpt")
    print sess.run(v1)

输出:

[ 1.  1.  1.]

创建多个saver对象

需要保存和恢复变量的不同子集时可以创建任意多个saver对象.

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0, shape=[3]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[5]), name="v2")

init_op = tf.initialize_all_variables()

saver1 = tf.train.Saver({"variable_v1":v1})
saver2 = tf.train.Saver({"variable_v2":v2})
with tf.Session() as sess:
    sess.run(init_op)
    saver1.save(sess, "model/model_v1.ckpt")
    saver2.save(sess, "model/model_v2.ckpt")

可以在model目录下看到:

这里写图片描述

同一个变量也可被列入多个saver对象中,只有saver的restore()函数被运行时它的值才会被改变.

完整示例

模型存储

创建一个简单的TensorFlow模型用于二维数据的线性回归.定义一个Saver对象,并且在train_graph()方法中,通过100次迭代来最小化损失函数.然后,模型在每次迭代中以及优化完成后保存到本地磁盘.每次保存都会在磁盘上创建名为“checkpoint”的二进制文件.

#!/usr/bin/python
# coding:utf-8
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

tf.reset_default_graph()
# 为x和y点创建占位符
X = tf.placeholder("float")
Y = tf.placeholder("float")
# 初始化需要学习的两个参数
h_est = tf.Variable(0.0, name='hor_estimate')
v_est = tf.Variable(0.0, name='ver_estimate')
# y_est保存y轴上的估计值
y_est = tf.square(X - h_est) + v_est
# 将损失函数定义为Y和y_est之间的平方距离
cost = (tf.pow(Y - y_est, 2))
# 最小化损失函数,学习率为0.001
trainop = tf.train.GradientDescentOptimizer(0.001).minimize(cost)
# 水平和垂直方向进行移动
h = 1
v = -2
# 在训练数据中添加噪音
x_train = np.linspace(-2,4,201)
noise = np.random.randn(*x_train.shape) * 0.4
y_train = (x_train - h) ** 2 + v + noise
# 创建一个Saver对象
saver = tf.train.Saver()
init = tf.global_variables_initializer()
# 迭代100次
def train_graph():
    with tf.Session() as sess:
        sess.run(init)
        for i in range(100):
            for (x, y) in zip(x_train, y_train):
                # 将实际数据传入
                sess.run(trainop, feed_dict={X: x, Y: y})
                # print(x,y)
            # 在每次迭代中创建一个检查点
            saver.save(sess, './model_iter', global_step=i)
        # 保存最终模型
        saver.save(sess, './model_final')
        h_ = sess.run(h_est)
        v_ = sess.run(v_est)
    return h_, v_

if __name__=="__main__":
    result = train_graph()
    print("h_est = %.2f, v_est = %.2f" % result)
    # 可视化数据
    plt.rcParams['figure.figsize'] = (10, 6)
    plt.scatter(x_train, y_train)
    plt.xlabel('x_train')
    plt.ylabel('y_train')
    plt.show()

这里写图片描述

保存模型时,有4种类型的文件来保存它:

  • “.meta”文件:包含图结构.
  • “.data”文件:包含变量的值.
  • “.index”文件:标识检查点.
  • “checkpoint”文件:包含最近检查点列表的protocol buffer.

这里写图片描述

Saver构造函数的其他一些参数可以控制整个过程:
max_to_keep:保留的最大检查点数量;
keep_checkpoint_every_n_hours:保存检查点的时间间隔.

模型恢复

在下面的例子中加载模型,并打印出两个系数的数值h_est和v_est:

#!/usr/bin/python
# coding:utf-8
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

tf.reset_default_graph()
# 为x和y点创建占位符
X = tf.placeholder("float")
Y = tf.placeholder("float")
# 初始化需要学习的两个参数
h_est = tf.Variable(0.0, name='hor_estimate')
v_est = tf.Variable(0.0, name='ver_estimate')
# y_est保存y轴上的估计值
y_est = tf.square(X - h_est) + v_est
# 将损失函数定义为Y和y_est之间的平方距离
cost = (tf.pow(Y - y_est, 2))
# 最小化损失函数,学习率为0.001
trainop = tf.train.GradientDescentOptimizer(0.001).minimize(cost)
# 水平和垂直方向进行移动
h = 1
v = -2
# 在训练数据中添加噪声
x_train = np.linspace(-2,4,201)
noise = np.random.randn(*x_train.shape) * 0.4
y_train = (x_train - h) ** 2 + v + noise
tf.reset_default_graph()
imported_meta = tf.train.import_meta_graph("./model_final.meta")

with tf.Session() as sess:
    # 恢复
    imported_meta.restore(sess, tf.train.latest_checkpoint('./'))
    h_est2 = sess.run('hor_estimate:0')
    v_est2 = sess.run('ver_estimate:0')
    print("h_est: %.2f, v_est: %.2f" % (h_est2, v_est2))

plt.scatter(x_train, y_train, label='train data')
plt.plot(x_train, (x_train - h_est2) ** 2 + v_est2, color='red', label='model')
plt.xlabel('x_train')
plt.ylabel('y_train')
plt.legend()
plt.show()

这里写图片描述


参考:

tf.train.Saver

tf.train.import_meta_graph

TensorFlow: Save and Restore Models

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值