[Study Notes] get_checkpoint_state()
1 Introduction
I recently reproduced my first project. Because training takes a very long time, I started looking into checkpoint-based resumption so that training could be split into stages. After some digging I found that the original code does support it, but even after changing the checkpoint path, training still restarted from scratch every time. Inspecting the checkpoint folder showed that a checkpoint is saved every 5 epochs, so my guess is that the source code also needs to be told the filename of the checkpoint to resume from. The notes below work through get_checkpoint_state().
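For context, the "checkpoint" file that this function reads is a small text-format protobuf sitting in the checkpoint directory next to the saved model files. A hypothetical example (the step numbers here are made up for illustration) for a run that saves every 5 epochs might look like:

```
model_checkpoint_path: "model.ckpt-25"
all_model_checkpoint_paths: "model.ckpt-15"
all_model_checkpoint_paths: "model.ckpt-20"
all_model_checkpoint_paths: "model.ckpt-25"
```

The model_checkpoint_path entry is what determines which checkpoint is treated as the latest one, which is why resuming depends on this file pointing at the right name.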
2 Overview
The relevant TensorFlow source is as follows:
@tf_export("train.get_checkpoint_state")
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.

  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
  f = None
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if file_io.file_exists(coord_checkpoint_filename):
      file_content = file_io.read_file_to_string(coord_checkpoint_filename)
      ckpt = CheckpointState()
      text_format.Merge(file_content, ckpt)
      if not ckpt.model_checkpoint_path:
        raise ValueError("Invalid checkpoint state loaded from "
                         + checkpoint_dir)
      # (the remainder of the function, which resolves relative paths
      # and handles parse errors, is omitted in this excerpt)
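To make the behavior concrete, here is a minimal standard-library sketch of what get_checkpoint_state() does, without TensorFlow itself: parse the text-format "checkpoint" file, return the recorded model_checkpoint_path (joined onto the directory when the recorded path is relative), or None when the file is absent. The function name read_checkpoint_state and the file contents are my own illustration, not part of the TensorFlow API.

```python
import os
import tempfile

def read_checkpoint_state(checkpoint_dir, latest_filename="checkpoint"):
    """Sketch of get_checkpoint_state(): read the text-format index file
    and return the latest model checkpoint path, or None if missing."""
    index_path = os.path.join(checkpoint_dir, latest_filename)
    if not os.path.exists(index_path):      # mirrors file_io.file_exists()
        return None
    model_path = None
    with open(index_path) as f:
        for line in f:
            key, _, value = line.partition(":")
            if key.strip() == "model_checkpoint_path":
                model_path = value.strip().strip('"')
    if not model_path:
        raise ValueError("Invalid checkpoint state loaded from "
                         + checkpoint_dir)
    if not os.path.isabs(model_path):       # relative paths are joined
        model_path = os.path.join(checkpoint_dir, model_path)
    return model_path

# Usage: fabricate a checkpoint directory and read it back.
with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "checkpoint"), "w") as f:
        f.write('model_checkpoint_path: "model.ckpt-25"\n')
        f.write('all_model_checkpoint_paths: "model.ckpt-20"\n')
    print(os.path.basename(read_checkpoint_state(d)))  # model.ckpt-25
```

This mirrors why my run restarted from scratch: if the directory passed in has no "checkpoint" file (or it names the wrong checkpoint), the result is None and the training script falls back to fresh initialization.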