模型训练和评估的一个通用的代码框架
import tensorflow as tf
# 初始化变量和模型参数,定义训练闭环中的运算
def inference(X):
# 计算推断模型在数据 X 上的输出,并将结果返回
def loss(X, Y):
# 依据训练数据 X 及其期望输出 Y 计算损失
def inputs():
# 读取或生成训练数据 X 及其期望输出 Y
def train(total_loss):
# 依据计算的总损失训练或调整模型参数
def evaluate(sess, X, Y):
# 对训练得到的模型进行评估
# 在一个会话中启动数据流图
with tf.Session() as sess:
tf.initialize_all_variables().run()
X, Y = inputs()
total_loss = loss(X, Y)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
# 实际的训练迭代次数
training_steps = 1000
for step in range(training_steps):
sess.run([train_op])
# 出于调试和学习的目的,查看损失函数在训练工程中递减的情况
if step % 10 == 0:
print('loss:' + str(sess.run(total_loss)))
evaluate(sess, X, Y)
coord.request_stop()
coord.join(threads)
以上便是模型训练和评估的基本代码框架。
训练模型,意味着通过许多个训练周期更新其参数(或者用 tensorflow 的语言来说,变量)。训练后,将模型进行保存,需要用到 tf.train.Saver 类,从而将数据流图中的变量保存到专门的二进制文件中。
我们应当周期地保存所有变量,创建检查点(checkpoint)文件,并在必要时从最近的检查点恢复训练。
# 模型定义代码
...
# 创建一个 Saver 对象
saver = tf.train.Saver()
# 在一个会话对象中启用数据流图,搭建流程
with tf.Session() as sess:
# 模型设置 ...
# 实际的训练闭环 ...
for step in range(training_steps):
sess.run([train_op])
if step % 1000 = 0:
saver.save(sess, 'my-model', golbal_step=step)
# 模型评估 ...
saver.save(sess, 'my-model', global_step=training_steps)
上述代码中,在开启会话对象之前实例化了一个 saver 对象,然后在训练闭环部分插入了几行代码,使得每次完成 1000 次训练迭代便调用一次 tf.train.Saver.save 方法,并在训练结束后,再次调用该方法。
如果希望从某个检查点恢复训练,则应使用 tf.train.get_checkpoint_state 方法,以验证之前是否有检查点文件被保存下来,而 tf.train.Saver.resotre 方法将负责恢复变量的值。
with tf.Session() as tf:
# 模型设置...
initial_step = 0
# 验证之前是否已经保存了检查点文件
ckpt = tf.train.get_checkpoint_state(os.path.dirname(__file__))
if ckpt and ckpt.model_checkpoint_path:
# 从检查点恢复模型参数
saver.restore(sess, ckpt.model_check_path)
initial_step = int(ckpt.model_checkpoint_path.rsplit('-', 1)[1])
# 实际的闭环训练
for step in range(initial_step, training_steps):
...
在上述代码中,首先检查是否有检查点文件存在,并在开始训练闭环前恢复各变量的值,还可依据检查点文件的名称恢复全局迭代次数。