TensorFlow 提供两种模型格式:
- checkpoint,这是取决于创建模型的代码的格式。
- SavedModel,这是独立于创建模型的代码的格式。
保存部分训练的模型
Estimator自动将这些写入硬盘:
- checkpoints,即训练期间创建的模型的版本。
- event files,其中包含TensorBoard用于创建可视化的信息。
要将值分配给任何Estimator的构造函数的可选model_dir参数以指定Estimator存储其信息的顶级目录。 例如,以下代码将model_dir参数设置为models / iris目录:
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris')
假设调用评估器的训练方法。例如:
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
如下图所示,第一次训练调用将检查点和其他文件添加到model_dir目录中:
默认检查点文件目录
如果没有在构造函数里指定 model_dir,Estimator将检查点文件写入由Python的tempfile.mkdtemp函数选择的临时目录。
检查点保存频率
默认规则:
- 每600秒写一次checkpoint
- 当train 方法开始(第一次迭代)和完成(最终迭代)时写入检查点
- 只保留目录中最近的5个检查点。
也可以通过以下步骤修改:
- 创建一个定义所需计划的RunConfig对象。
- 在实例化Estimator时,将该RunConfig对象传递给Estimator的config参数。
例如以下代码:
my_checkpointing_config = tf.estimator.RunConfig(
save_checkpoints_secs = 20*60, # Save checkpoints every 20 minutes.
keep_checkpoint_max = 10, # Retain the 10 most recent checkpoints.
)
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris',
config=my_checkpointing_config)
从文件恢复模型
第一次调用Estimator的train方法时,TensorFlow将检查点保存到model_dir。随后每次调用Estimator的train,eval或predict方法都会导致以下情况:
- Estimator通过运行model_fn() 来建立模型的图。
- Estimator根据最近检查点中存储的数据初始化新模型的权重
也就是说,一旦存在检查点,TensorFlow会在每次调用train(),evaluate()或predict()时重建模型。
避免错误恢复
从检查点恢复模型的状态只适用于模型和检查点是兼容的。例如,假设训练了一个包含两个隐藏层的DNNClassifier评估器,每个层都有10个节点:
classifier = tf.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris')
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
在训练之后(并且因此在models/iris中创建检查点之后),将每个隐藏层中的神经元数量从10更改为20,然后尝试重新训练模型:
classifier2 = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[20, 20], # Change the number of neurons in the model.
n_classes=3,
model_dir='models/iris')
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
由于检查点中的状态与classifier2中描述的模型不兼容,重新训练失败,并出现以下错误:
...
InvalidArgumentError (see above for traceback): tensor_name =
dnn/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
does not match the shape stored in checkpoint: [20]
要运行训练和比较稍有不同版本的模型的实验,保存创建每个model-dir的代码副本,可能需要为每个版本创建一个单独的git分支。 这种分离将保持检查点的可恢复性。