【学习笔记】get_checkpoint_state()学习

本文介绍了在TensorFlow中使用get_checkpoint_state()进行断点续训的学习过程。通过分析该函数的输入、输出及内部工作原理,理解如何从checkpoint目录恢复模型,并探讨latest_filename参数的含义及其对训练的影响。同时,提到了os.path.join()在路径拼接中的应用。
摘要由CSDN通过智能技术生成

【学习笔记】get_checkpoint_state()学习
1 引言
最近复现了第一个项目,由于项目需要训练的时间非常长,所以开始考虑用断点续训的方法来进行分段训练,找了一通之后发现,其实源代码中是有写的,更改了了checkpoint的地址后,每次还是重头开始训练,查看了checkpoint的文件夹每隔5个eproch保存了一个checkpoint,推测应该是还需要在源代码中设置开始训练的checkpoint的文件名称,下面针对get_checkpoint()开始学习。

2 总体说明
具体代码如下


@tf_export("train.get_checkpoint_state")
def get_checkpoint_state(checkpoint_dir, latest_filename=None):
  """Returns CheckpointState proto from the "checkpoint" file.

  If the "checkpoint" file contains a valid CheckpointState
  proto, returns it.

  Args:
    checkpoint_dir: The directory of checkpoints.
    latest_filename: Optional name of the checkpoint file.  Default to
      'checkpoint'.

  Returns:
    A CheckpointState if the state was available, None
    otherwise.

  Raises:
    ValueError: if the checkpoint read doesn't have model_checkpoint_path set.
  """
  ckpt = None
  coord_checkpoint_filename = _GetCheckpointFilename(checkpoint_dir,
                                                     latest_filename)
  f = None
  try:
    # Check that the file exists before opening it to avoid
    # many lines of errors from colossus in the logs.
    if file_io.
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值