本文介绍了 Estimators 模型的保存和恢复。
TensorFlow提供了两种模型格式:
- checkpoints:这种格式依赖于创建模型的代码。
- SavedModel:这种格式与创建模型的代码无关。
本文档主要介绍checkpoints。要详细了解 SavedModel
,请参阅《TensorFlow 编程人员指南》的 Saving and Restoring 一章。
1. 保存经过部分训练的模型
Estimators 在训练过程中会自动将以下内容保存到磁盘:
- chenkpoints:训练过程中的模型快照。
- event files:其中包含 TensorBoard 用于创建可视化图表的信息。
通过 model_dir
参数,我们可以指定 Estimator 保存上述文件时的顶级目录。
# 实例化 estimator
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris')
# 训练 estimator
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
如下图所示,第一次调用 train
方法会将 checkpoints 和 event files 文件添加到 model_dir