pytorch学习(一)pytorch中的断点续训

1. 设置断点续训的目的

在遇到停电宕机,设备内存不足导致实验还没有跑完的情况下,如果没有使用断点续训,就需要从头开始训练,耗时费力。
断点续训主要保存的是网络模型的参数以及优化器optimizer的状态(因为很多情况下optimizer的状态会改变,比如学习率的变化)

2. 设置断点续训的方法

  1. 参数设置
    resume: 是否进行续训
    initepoch: 进行续训时的初始epoch
  2. checkpoint载入过程(这部分操作放在epoch循环前边)
resume = True      # 设置是否需要从上次的状态继续训练
    if resume:
        if os.path.isfile("results/{}_model.pth".format(save_name_pre)):
            print("Resume from checkpoint...")
            checkpoint = torch.load("results/{}_model.pth".format(save_name_pre))
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            initepoch = checkpoint['epoch'] + 1
            print("====>loaded checkpoint (epoch{})".format(checkpoint['epoch']))
        else:
            print("====>no checkpoint found.")
            initepoch = 1   # 如果没进行训练过,初始训练epoch值为1
  1. 每一轮,checkpoint的存储过程,保存模型参数,优化器参数,轮数(这部分操作放在epoch循环里边)
# 保存断点
        if test_acc_1 > best_acc:
            best_acc = test_acc_1
            checkpoint = {"model_state_dict": model.state_dict(),
                          "optimizer_state_dict": optimizer.state_dict(),
                          "epoch": epoch}
            path_checkpoint = "results/{}_model.pth".format(save_name_pre)
            torch.save(checkpoint, path_checkpoint)
  • 8
    点赞
  • 41
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
PyTorch Lightning 提供了断点续训的功能,方便在训练过程出现意外情况时恢复训练。要实现断点续训,你需要使用 PyTorch Lightning 提供的回调函数 ModelCheckpoint。 首先,你需要在 LightningModule 定义一个回调函数 ModelCheckpoint,并将其传递给 Trainer。你可以指定保存模型权重的路径、监测的指标以及保存策略等。 下面是一个示例代码: ```python import pytorch_lightning as pl from pytorch_lightning.callbacks import ModelCheckpoint class MyModel(pl.LightningModule): def __init__(self): super().__init__() # 定义模型结构和参数 def training_step(self, batch, batch_idx): # 训练步骤 def validation_step(self, batch, batch_idx): # 验证步骤 def configure_optimizers(self): # 配置优化器 def train_dataloader(self): # 返回训练数据加载器 def val_dataloader(self): # 返回验证数据加载器 # 定义回调函数,设置保存路径和保存策略 checkpoint_callback = ModelCheckpoint( monitor='val_loss', dirpath='/path/to/save/checkpoints/', filename='my_model-{epoch:02d}-{val_loss:.2f}', save_top_k=3, mode='min', ) # 创建 LightningModule 实例和 Trainer 对象 model = MyModel() trainer = pl.Trainer(callbacks=[checkpoint_callback]) # 使用 Trainer 进行训练 trainer.fit(model) ``` 在训练过程,ModelCheckpoint 回调函数会自动保存最好的模型权重,以及根据保存策略保留指定数量的模型权重。如果训练断,你可以通过加载最新的检查点文件来恢复训练。 希望这能帮到你!如果还有其他问题,请随时提问。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值