TensorFlow2官方指南 :训练检查点Checkpoint 指南

本文详述了TensorFlow2中检查点Checkpoint的使用,包括如何编写和读取检查点,以及如何利用tf.train.Checkpoint、tf.keras.Model.save_weights进行模型的保存与恢复,特别强调了在训练过程中的应用及其加载机制。
摘要由CSDN通过智能技术生成

本文介绍用于编写和读取检查点的 API


训练检查点指南

检查点可以捕获模型使用的所有参数(tf.Variable 对象)的确切值。检查点不包含对模型所定义计算的任何描述,因此通常仅在将使用保存参数值的源代码可用时才有用。

编写和读取检查点的 API。

import tensorflow as tf

class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)


net = Net()

tf.keras.Model.save_weights 可以保存一个 TensorFlow 检查点。

net.save_weights('easy_checkpoint')

编写检查点

TensorFlow 模型的持久状态存储在 tf.Variable 对象中。这些对象可以直接构造,但通常会通过像 tf.keras.layers 或 tf.keras.Model 这样的高级 API 创建。

管理变量的最简单方法是将它们附加到 Python 对象,然后引用这些对象。

tf.train.Checkpoint、tf.keras.layers.Layer 和 tf.keras.Model 的子类会自动跟踪分配给其特性的变量。下面的示例构造了一个简单的线性模型,然后编写检查点,其中包含该模型所有变量的值。

您可以使用 Model.save_weights 轻松保存模型检查点。

tf.train.Checkpoint

为了帮助演示 tf.train.Checkpoint 的所有功能, 下面定义了一个小数据集和优化步骤:

def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值