tensorflow模型代码框架

模型训练和评估的一个通用的代码框架

import tensorflow as tf

# 初始化变量和模型参数,定义训练闭环中的运算

def inference(X):
  # 计算推断模型在数据 X 上的输出,并将结果返回

def loss(X, Y):
  # 依据训练数据 X 及其期望输出 Y 计算损失

def inputs():
  # 读取或生成训练数据 X 及其期望输出 Y

def train(total_loss):
  # 依据计算的总损失训练或调整模型参数

def evaluate(sess, X, Y):
  # 对训练得到的模型进行评估

# 在一个会话中启动数据流图
with tf.Session() as sess:
  tf.initialize_all_variables().run()

  X, Y = inputs()

  total_loss = loss(X, Y)

  coord = tf.train.Coordinator()

  threads = tf.train.start_queue_runners(sess=sess, coord=coord)

  # 实际的训练迭代次数
  training_steps = 1000
  for step in range(training_steps):
    sess.run([train_op])
    # 出于调试和学习的目的,查看损失函数在训练工程中递减的情况
    if step % 10 == 0:
      print('loss:' + str(sess.run(total_loss)))

  evaluate(sess, X, Y)

  coord.request_stop()
  coord.join(threads)

以上便是模型训练和评估的基本代码框架。

训练模型,意味着通过许多个训练周期更新其参数(或者用 tensorflow 的语言来说,变量)。训练后,将模型进行保存,需要用到 tf.train.Saver 类,从而将数据流图中的变量保存到专门的二进制文件中。

我们应当周期地保存所有变量,创建检查点(checkpoint)文件,并在必要时从最近的检查点恢复训练。

# 模型定义代码
...
# 创建一个 Saver 对象
saver = tf.train.Saver()

# 在一个会话对象中启用数据流图,搭建流程
with tf.Session() as sess:

  # 模型设置 ...

  # 实际的训练闭环 ...
  for step in range(training_steps):
    sess.run([train_op])

    if step % 1000 = 0:
      saver.save(sess, 'my-model', golbal_step=step)

  # 模型评估 ...

  saver.save(sess, 'my-model', global_step=training_steps)

上述代码中,在开启会话对象之前实例化了一个 saver 对象,然后在训练闭环部分插入了几行代码,使得每次完成 1000 次训练迭代便调用一次 tf.train.Saver.save 方法,并在训练结束后,再次调用该方法。

如果希望从某个检查点恢复训练,则应使用 tf.train.get_checkpoint_state 方法,以验证之前是否有检查点文件被保存下来,而 tf.train.Saver.resotre 方法将负责恢复变量的值。

with tf.Session() as tf:
  # 模型设置...
  initial_step = 0
  # 验证之前是否已经保存了检查点文件
  ckpt = tf.train.get_checkpoint_state(os.path.dirname(__file__))
  if ckpt and ckpt.model_checkpoint_path:
    # 从检查点恢复模型参数
    saver.restore(sess, ckpt.model_check_path)
    initial_step = int(ckpt.model_checkpoint_path.rsplit('-', 1)[1])
  
  # 实际的闭环训练
  for step in range(initial_step, training_steps):
    ...

在上述代码中,首先检查是否有检查点文件存在,并在开始训练闭环前恢复各变量的值,还可依据检查点文件的名称恢复全局迭代次数。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值