pytorch lightning 按照频率/epoch/step保存模型或checkpoint

需求

在训练深度神经网络时,如果训练时间较长,我们通常希望在训练过程中定期保存模型的参数,以便稍后从该点恢复训练或进行推理。PyTorch Lightning 提供了 ModelCheckpoint 回调函数来帮助我们自动保存模型参数。
在本文中,我们将探讨如何使用 PyTorch Lightning 训练模型并使用 ModelCheckpoint 自动从训练过程中保存模型的参数。

方法

pytorch lightning 提供了保存 checkpoint API https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#lightning.pytorch.callbacks.ModelCheckpoint
利用 **every_n_train_steps 、train_time_interval 、every_n_epochs **设置保存 checkpoint 的按照步数、时间、epoch数来保存 checkpoints,注意三者互斥,如果要同时实现对应的功能需要创建多个 MODELCHECKPOINT
利用 save_top_k 设置保存所有的模型,因为不是通过 monitor 的形式保存的,所以 save_top_k 只能设置为 -1,0,1,分别表示保存所有的模型,不保存模型和保存最后一个模型

示例

准备数据集和模型

首先,我们需要准备数据集和模型。这里我们使用 PyTorch 的 FashionMNIST 数据集和一个简单的卷积神经网络作为模型。

import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl

class Net(pl.LightningModule):
    def __init__(self, num_classes=10):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def cross_entropy_loss(self, logits, labels):
        return nn.functional.cross_entropy(logits, labels)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log('val_loss', loss, logger=True)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.01)
        return optimizer

# 准备数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_set = FashionMNIST(".", train=True, transform=transform, download=True)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
val_set = FashionMNIST(".", train=False, transform=transform, download=True)
val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=2)

训练模型

接下来,我们使用 Trainer 类来训练模型。

# 训练模型
trainer = pl.Trainer(
    gpus=1,
    max_epochs=5,
    callbacks=[pl.callbacks.ModelCheckpoint(every_n_train_steps=60, save_top_k=-1)]
)
model = Net(num_classes=10)
trainer.fit(model, train_loader, val_loader)

上面的代码将训练模型 5 个 epoch,并在每训练 60 步(batch)时保存一个 checkpoint。ModelCheckpoint 回调函数的 save_top_k 参数为 -1,表示保存所有 checkpoint。

加载已保存的模型

当我们需要从保存的 checkpoint 恢复模型时,可以使用 Trainer 类的 resume_from_checkpoint 参数:

trainer = pl.Trainer(
    gpus=1,
    max_epochs=5,
    callbacks=[pl.callbacks.ModelCheckpoint(every_n_train_steps=60, save_top_k=-1)],
    resume_from_checkpoint='path/to/checkpoint.ckpt'
)
model = Net(num_classes=10)
trainer.fit(model, train_loader, val_loader)
  • 5
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值