tensorflow 检查点和模型,保存与恢复使用,官方教程(一)

  • 检查点:这种格式依赖于创建模型的代码。
  • SavedModel:这种格式与创建模型的代码无关。

示例代码

本文档依赖于 TensorFlow 使用入门中详细介绍的同一个鸢尾花分类示例。要下载和访问该示例,请执行下列两个命令:

 
git clone https://github.com/tensorflow/models/cd models/samples/core/get_started

本文档中的大部分代码段都是在 premade_estimator.py 的基础上加以细微修改。

保存经过部分训练的模型

Estimator 自动将以下内容写入磁盘:

  • 检查点:训练期间所创建的模型版本。
  • 事件文件:其中包含 TensorBoard 用于创建可视化图表的信息。

要指定 Estimator 在其中存储其信息的顶级目录,请为任何 Estimator 的构造函数的可选 model_dir 参数分配一个值。例如,以下代码将 model_dir 参数设置为 models/iris 目录:

classifier = tf.estimator.DNNClassifier(
    feature_columns
=my_feature_columns,
    hidden_units
=[10, 10],
    n_classes
=3,
    model_dir
='models/iris')

假设您调用 Estimator 的 train 方法。例如:

classifier.train(        input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),                steps=200)

如下图所示,第一次调用 train 会将检查点和其他文件添加到 model_dir 目录中:

第一次调用 train()。

要在基于 UNIX 的系统上查看创建的 model_dir 目录中的对象,只需按如下方式调用 ls 即可:

 
$ ls -1 models/iris
checkpoint
events
.out.tfevents.timestamp.hostname
graph
.pbtxt
model
.ckpt-1.data-00000-of-00001
model
.ckpt-1.index
model
.ckpt-1.meta
model
.ckpt-200.data-00000-of-00001
model
.ckpt-200.index
model
.ckpt-200.meta

上述 ls 命令展示了 Estimator 在第 1 步(训练开始)和第 200 步(训练结束)创建了检查点。

默认检查点目录

如果您未在 Estimator 的构造函数中指定 model_dir,则 Estimator 会将检查点文件写入由 Python 的 tempfile.mkdtemp 函数选择的临时目录中。例如,以下 Estimator 构造函数不会指定 model_dir 参数:

classifier = tf.estimator.DNNClassifier(    feature_columns=my_feature_columns,    hidden_units=[10, 10],    n_classes=3)print(classifier.model_dir)

tempfile.mkdtemp 函数会根据您的操作系统选择安全的临时目录。例如,Mac 操作系统上的典型临时目录可能如下所示:

/var/folders/0s/5q9kfzfj3gx2knj0vj8p68yc00dhcr/T/tmpYm1Rwa

检查点频率

默认情况下,Estimator 按照以下时间安排将检查点保存到 model_dir 中:

  • 每 10 分钟(600 秒)写入一个检查点。
  • 在 train 方法开始(第一次迭代)和完成(最后一次迭代)时写入一个检查点。
  • 只在目录中保留 5 个最近写入的检查点。

您可以通过执行下列步骤来更改默认时间安排:

  1. 创建一个 RunConfig 对象来定义所需的时间安排。
  2. 在实例化 Estimator 时,将该 RunConfig 对象传递给 Estimator 的 config 参数。

例如,以下代码将检查点时间安排更改为每 20 分钟保存一次并保留 10 个最近写入的检查点:

my_checkpointing_config = tf.estimator.RunConfig(    save_checkpoints_secs = 20*60,  # Save checkpoints every 20 minutes.    keep_checkpoint_max = 10,       # Retain the 10 most recent checkpoints.)classifier = tf.estimator.DNNClassifier(    feature_columns=my_feature_columns,    hidden_units=[10, 10],    n_classes=3,    model_dir='models/iris',    config=my_checkpointing_config)

恢复模型

第一次调用 Estimator 的 train 方法时,TensorFlow 会将一个检查点保存到 model_dir 中。随后每次调用 Estimator 的 trainevaluate 或 predict 方法时,都会发生下列情况:

  1. Estimator 通过运行 model_fn() 构建模型。(要详细了解 model_fn(),请参阅创建自定义 Estimator。)
  2. Estimator 根据最近写入的检查点中存储的数据来初始化新模型的权重。

换言之,如下图所示,一旦存在检查点,TensorFlow 就会在您每次调用 train()evaluate() 或 predict() 时重建模型。

对 train()、evaluate() 或 predict() 的后续调用

避免不当恢复

通过检查点恢复模型的状态这一操作仅在模型和检查点兼容时可行。例如,假设您训练了一个 DNNClassifier Estimator,它包含 2 个隐藏层且每层都有 10 个节点:

classifier = tf.estimator.DNNClassifier(
    feature_columns
=feature_columns,
    hidden_units
=[10, 10],
    n_classes
=3,
    model_dir
='models/iris')

classifier
.train(
    input_fn
=lambda:train_input_fn(train_x, train_y, batch_size=100),
        steps
=200)

在训练之后(因此已在 models/iris 中创建检查点),假设您将每个隐藏层中的神经元数量从 10 更改为 20,然后尝试重新训练模型:

classifier2 = tf.estimator.DNNClassifier(    feature_columns=my_feature_columns,    hidden_units=[20, 20],  # Change the number of neurons in the model.    n_classes=3,    model_dir='models/iris')classifier.train(    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),        steps=200)

由于检查点中的状态与 classifier2 中描述的模型不兼容,因此重新训练失败并出现以下错误:

...
InvalidArgumentError (see above for traceback): tensor_name =
dnn
/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
does
not match the shape stored in checkpoint: [20]

要运行实验(训练并比较略微不同的模型版本),请保存一份创建每个 model_dir 所需的代码的副本,同时可能需要为每个版本创建一个单独的 git 分支。这种区分将有助于保证检查点的可恢复性。


评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值