tensorflow保存及载入模型、添加检查点

在训练完模型之后,就要把模型保存起来,方便以后使用。

保存模型 save():
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(...)
    saver.save(sess, savePath/fileName)
载入模型 restore() :
saver = tf.train.Saver()
with tf.Session() as sess1:
    sess.run(...)
    saver.restore(sess1, savePath/fileName)
保存模型时也可以指定变量名字与变量的对应关系:
1)saver = tf.train.Saver({key1: valve, key2: value})
   例: saver = tf.train.Saver({'weights': w, biases: b})
2)saver = tf.train.Saver([w, b])  # 放到list中
3)saver = tf.train.Saver(v.op.name: v for v in [w, b])  # 将op的名字当做key
打印模型内容:
print_tensors_in_checkpoint_file(save_dir+'linerModel.cpkt', None, True)

下面是一个保存及载入模型的完整例子:

import numpy as np
import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

# 构建实验数据
train_x = np.linspace(-1, 1, 100)
# y = 2 * x + b
train_y = 2. * train_x + np.random.randn(*train_x.shape) * 0.3

# 创建模型
# 占位符
X = tf.placeholder(tf.float32)
Y = tf.placeholder(tf.float32)
# 模型参数
weights = tf.Variable(tf.random_normal([1]), name='weights')
biases = tf.Variable(tf.zeros([1]), name='biases')
z = tf.multiply(X, weights) + biases

# 构建损失函数
loss = tf.reduce_mean(tf.square(Y - z))
# 定义学习率
learning_rate = 0.01
# 构建优化函数
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# 最小化损失函数
train = optimizer.minimize(loss)

# 初始化所有变量
init = tf.global_variables_initializer()

# 定义 epochs
training_epochs = 20
# 每隔两步显示一次中间值
display_step = 2

# 存放批次值和损失值
plot_data = {'batchsize': [], 'loss': []}


# 定义保存模型对象
saver = tf.train.Saver()
save_dir = 'logs/'  # 生成模型的路径
# 启动Session
with tf.Session() as sess:
    # 初始化全局变量
    sess.run(init)

    # 向模型中 feed 数据
    for epoch in range(training_epochs):
        for (x, y) in zip(train_x, train_y):
            feed_dict = {X: x, Y: y}
            sess.run(train, feed_dict=feed_dict)

        # 显示训练中的数据
        if epoch % display_step == 0:
            loss_ = sess.run(loss, feed_dict={X: train_x, Y: train_y})
            print('epoch:', epoch + 1, 'loss = ', loss_, 'weights=',
                  sess.run(weights), 'biases=', sess.run(biases))

    print('Finished...')
    # 保存模型
    saver.save(sess, save_dir+'linerModel.cpkt')  # 如果指定的文件夹不存在会自动创建
    print('loss=', sess.run(loss, feed_dict={X: train_x, Y: train_y}), 'weights=',
          sess.run(weights), 'biases=', sess.run(biases))


# 使用模型
with tf.Session() as sess_2:
    saver.restore(sess_2, save_dir+'linerModel.cpkt')
    print('下面是模型载入结果: ')
    print('x=0.2, z=', sess_2.run(z, feed_dict={X: 0.2}))


# 打印文件内容
print_tensors_in_checkpoint_file(save_dir+'linerModel.cpkt', None, True)
添加检查点 Checkpoint
在训练之中,难免会出现中断的情况,这时就设置一个检查点。
saver = tf.train.Saver(max_to_keep=1)  # 生成saver
with tf.Session() as sess1:
    sess.run(...)
    saver.save(sess1, savePath/fileName, global_step=epoch)
max_to_keep 参数指定最多生成多少个检查点文件
 载入检查点
load_epoch = 18           #  只是文件的一个后缀,可以根据需要修改
with tf.Session() as sess_2:
    saver.restore(sess_2, save_dir+'linearModel.cpkt-' + str(load_epoch))
另一种添加检查点的方式: trainMonitored
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='logs/ckpt',
                                       save_checkpoint_secs=2) as sess:
    print(sess.run([global_step]))
    while not sess.should_stop():  # 启用死循环,session不停止就不结束
        i = sess.run(step)
        print(i)
如果不设置 save_checkpoint_secs 参数,默认时间是10mins,
该方法必须定义global_step,不然报错

下面是添加及载入检查点的完整例子:

import numpy as np
import tensorflow as tf

# 构建实验数据
train_x = np.linspace(-1, 1, 100)
# y = 2 * x + b
train_y = 2. * train_x + np.random.randn(*train_x.shape) * 0.3

# 创建模型
# 占位符
X = tf.placeholder(tf.float32)
Y = tf.placeholder(tf.float32)
# 模型参数
weights = tf.Variable(tf.random_normal([1]), name='weights')
biases = tf.Variable(tf.zeros([1]), name='biases')
z = tf.multiply(X, weights) + biases

# 构建损失函数
loss = tf.reduce_mean(tf.square(Y - z))
# 定义学习率
learning_rate = 0.01
# 构建优化函数
optimizer = tf.train.GradientDescentOptimizer(learning_rate)
# 最小化损失函数
train = optimizer.minimize(loss)

# 初始化所有变量
init = tf.global_variables_initializer()

# 定义 epochs
training_epochs = 20
# 每隔两步显示一次中间值
display_step = 2

# 存放批次值和损失值
plot_data = {'batchsize': [], 'loss': []}


# 定义保存模型对象
saver = tf.train.Saver(max_to_keep=2)
save_dir = 'logs/'  # 生成模型的路径
# 启动Session
with tf.Session() as sess:
    # 初始化全局变量
    sess.run(init)

    # 向模型中 feed 数据
    for epoch in range(training_epochs):
        for (x, y) in zip(train_x, train_y):
            feed_dict = {X: x, Y: y}
            sess.run(train, feed_dict=feed_dict)

        # 显示训练中的数据
        if epoch % display_step == 0:
            loss_ = sess.run(loss, feed_dict={X: train_x, Y: train_y})
            print('epoch:', epoch + 1, 'loss = ', loss_, 'weights=',
                  sess.run(weights), 'biases=', sess.run(biases))
            # 保存检查点
            saver.save(sess, save_dir + 'linearModel.cpkt', global_step=epoch)
    print('Finished...')
    # 保存模型
    saver.save(sess, save_dir+'linerModel.cpkt')  # 如果指定的文件夹不存在会自动创建
    print('loss=', sess.run(loss, feed_dict={X: train_x, Y: train_y}), 'weights=',
          sess.run(weights), 'biases=', sess.run(biases))

# 载入检查点
load_epoch = 18
with tf.Session() as sess_2:
    saver.restore(sess_2, save_dir+'linearModel.cpkt-' + str(load_epoch))
    print('下面是检查点的结果: ')
    print('x=0.2, z=', sess_2.run(z, feed_dict={X: 0.2}))

# trainMonitored
tf.reset_default_graph()
global_step = tf.train.get_or_create_global_step()
step = tf.assign_add(global_step, 1)
with tf.train.MonitoredTrainingSession(checkpoint_dir='logs/ckpt',
                                       save_checkpoint_secs=2) as sess:
    print(sess.run([global_step]))
    while not sess.should_stop():  # 启用死循环,session不停止就不结束
        i = sess.run(step)
        print(i)
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值