PyTorch 模型保存,断点训练

在epoch前插入:

    initepoch = 0
    resume = True  # 设置是否需要从上次的状态继续训练
    if resume:
        if os.path.isfile("./testweights/last_model.pth"):
            print("Resume from checkpoint...")
            checkpoint = torch.load("./testweights/last_model.pth")
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            initepoch = checkpoint['epoch'] + 1
            print("====>loaded checkpoint (epoch{})".format(checkpoint['epoch']))
        else:
            print("====>no checkpoint found.")
            initepoch = 0  # 如果没进行训练过,初始训练epoch值为0

epoch循环改为:

for epoch in range(initepoch, args.epochs):

在epoch中插入:

        # save best epoch
        if val_acc > best_acc:
            best_acc = val_acc
            torch.save(model.state_dict(), "testweights/best_model.pth")
            print("!!--Best Model has Update--!!")

        # save epoch model
        torch.save(model.state_dict(), "./testweights/model-{}.pth".format(epoch))

        # save last model
        checkpoint = {"model_state_dict": model.state_dict(),
                      "optimizer_state_dict": optimizer.state_dict(),
                      "epoch": epoch}
        path_checkpoint = "./testweights/last_model.pth"
        torch.save(checkpoint, path_checkpoint)
        print("!!--Last Model has Update(-{})--!!".format(epoch))

### PyTorch在大规模模型训练中的作用与特性 #### 1. 动态计算图支持 PyTorch 提供了一种动态计算图机制,允许开发者在运行构建和修改计算图。这种灵活性特别适合处理复杂的自然语言处理 (NLP) 和序列建模任务,例如 ELMo 模型的实现[^2]。相比于静态图框架,动态图更易于调试并能更好地适配复杂的数据流。 #### 2. 高效的GPU加速 当面对像 ImageNet 这样的大规模数据集或者涉及数百万样本的任务,利用 GPU 加速变得至关重要。PyTorch 对 CUDA 的原生支持使模型能够在 GPU 设备上高效执行矩阵运算和其他密集操作,显著缩短训练间[^4]。例如,在一个包含百万级数据点的场景下,借助 GPU 可以将整个训练周期控制在合理范围内(约5小完成20个epoch)。 #### 3. 轻松加载预训练模型 迁移学习是解决小样本问题的有效手段之一。通过使用预先在大数据集如 ImageNet 上训练过的模型作为起点,可以快速调整其参数以适应特定领域的新任务需求[^1]。这不仅减少了重新从零开始所需的间成本,还提高了最终解决方案的质量。此外,PyTorch 社区提供了丰富的官方及第三方预训练模型库,方便用户直接下载使用。 #### 4. 断点续训功能 针对长间运行的大规模训练项目,不可避免会出现意外中断的情况。为此,PyTorch 支持定期保存当前状态字典(state_dict),其中包括网络结构、权重以及优化器的信息等。这样即便中途出现问题也可以恢复至最近一次存档位置继续工作而无需重头再来一遍。 #### 5. 完整工具链生态系统 除了核心框架外,围绕着 PyTorch 形成了众多辅助开发插件和服务平台,比如用于分布式训练的 torch.distributed API 或者可视化监控进度曲线变化趋势 tensorboardX 工具包等等 。它们共同构成了强大的技术支持体系帮助研究人员更加专注于算法设计本身而非底层细节管理方面的工作负担。 ```python import torch from torchvision import models, transforms from torch.utils.data import DataLoader # 使用预训练模型 ResNet-50 并冻结部分层 model = models.resnet50(pretrained=True) for param in model.parameters(): param.requires_grad = False num_ftrs = model.fc.in_features model.fc = torch.nn.Linear(num_ftrs, new_classes_count) criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.Adam(model.fc.parameters(), lr=learning_rate) device = 'cuda' if torch.cuda.is_available() else 'cpu' model.to(device) def train_model(dataloader_train, dataloader_val): best_accuracy = 0.0 for epoch in range(epochs): running_loss = 0.0 # 训练阶段 model.train() for inputs, labels in dataloader_train: optimizer.zero_grad() outputs = model(inputs.to(device)) loss = criterion(outputs, labels.to(device)) loss.backward() optimizer.step() running_loss += loss.item() avg_loss = running_loss / len(dataloader_train.dataset) # 验证阶段省略... checkpoint_path = f'model_epoch_{epoch}.pt' torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), }, checkpoint_path) ``` 上述代码片段展示了如何基于 PyTorch 构造一个简单的迁移学习流程,并包含了基本的断点存储逻辑以便后续可能需要用到的地方能够顺利衔接起来。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值