pytorch lightning 按照频率/epoch/step保存模型或checkpoint

需求

在训练深度神经网络时,如果训练时间较长,我们通常希望在训练过程中定期保存模型的参数,以便稍后从该点恢复训练或进行推理。PyTorch Lightning 提供了 ModelCheckpoint 回调函数来帮助我们自动保存模型参数。
在本文中,我们将探讨如何使用 PyTorch Lightning 训练模型并使用 ModelCheckpoint 自动从训练过程中保存模型的参数。

方法

pytorch lightning 提供了保存 checkpoint API https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#lightning.pytorch.callbacks.ModelCheckpoint
利用 **every_n_train_steps 、train_time_interval 、every_n_epochs **设置保存 checkpoint 的按照步数、时间、epoch数来保存 checkpoints,注意三者互斥,如果要同时实现对应的功能需要创建多个 MODELCHECKPOINT
利用 save_top_k 设置保存所有的模型,因为不是通过 monitor 的形式保存的,所以 save_top_k 只能设置为 -1,0,1,分别表示保存所有的模型,不保存模型和保存最后一个模型

示例

准备数据集和模型

首先,我们需要准备数据集和模型。这里我们使用 PyTorch 的 FashionMNIST 数据集和一个简单的卷积神经网络作为模型。

import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from torch.utils.data import DataLoader
import pytorch_lightning as pl

class Net(pl.LightningModule):
    def __init__(self, num_classes=10):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = x.view(-1, 64 * 7 * 7)
        x = nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

    def cross_entropy_loss(self, logits, labels):
        return nn.functional.cross_entropy(logits, labels)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.cross_entropy_loss(logits, y)
        self.log('val_loss', loss, logger=True)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=0.01)
        return optimizer

# 准备数据集
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
train_set = FashionMNIST(".", train=True, transform=transform, download=True)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
val_set = FashionMNIST(".", train=False, transform=transform, download=True)
val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=2)

训练模型

接下来,我们使用 Trainer 类来训练模型。

# 训练模型
trainer = pl.Trainer(
    gpus=1,
    max_epochs=5,
    callbacks=[pl.callbacks.ModelCheckpoint(every_n_train_steps=60, save_top_k=-1)]
)
model = Net(num_classes=10)
trainer.fit(model, train_loader, val_loader)

上面的代码将训练模型 5 个 epoch,并在每训练 60 步(batch)时保存一个 checkpoint。ModelCheckpoint 回调函数的 save_top_k 参数为 -1,表示保存所有 checkpoint。

加载已保存的模型

当我们需要从保存的 checkpoint 恢复模型时,可以使用 Trainer 类的 resume_from_checkpoint 参数:

trainer = pl.Trainer(
    gpus=1,
    max_epochs=5,
    callbacks=[pl.callbacks.ModelCheckpoint(every_n_train_steps=60, save_top_k=-1)],
    resume_from_checkpoint='path/to/checkpoint.ckpt'
)
model = Net(num_classes=10)
trainer.fit(model, train_loader, val_loader)
PyTorch Lightning是一个高级封装库,用于简化使用PyTorch构建模型的过程。它可以帮助用户更好地组织代码,使得研究代码更易读和维护。PyTorch Lightning中内置了ModelCheckpoint功能,主要用于在训练过程中自动保存模型的最佳版本,防止因过拟合训练过程中断而丢失模型。 ModelCheckpoint可以设置很多参数,比如保存频率保存条件、保存的文件名等。使用ModelCheckpoint时,用户可以指定监控某个指标(例如验证集上的准确率),并根据这个指标来保存最好的模型者在每个epoch保存者只保存最新的模型。 下面是一个使用PyTorch Lightning的ModelCheckpoint的基本示例: ```python from pytorch_lightning.callbacks import ModelCheckpoint # 创建ModelCheckpoint的回调实例 checkpoint_callback = ModelCheckpoint( monitor='val_loss', # 监控的指标,这里是验证集上的损失 dirpath='path/to/save', # 模型保存的路径 filename='model-{epoch:02d}-{val_loss:.2f}', # 文件名格式 save_top_k=3, # 保存top k个模型,这里是保存最好的3个 mode='min', # 指定监控指标是希望最小化(min)还是最大化(max),这里是损失最小化 save_weights_only=False, # 是否只保存模型权重,默认为False,即保存整个模型 period=1 # 指定多少个epoch保存一次,默认是每个epoch保存 ) # 定义Lightning模块 class LitModel(pl.LightningModule): def __init__(self): super().__init__() # 模型定义代码 def training_step(self, batch, batch_idx): # 训练步骤代码 def validation_step(self, batch, batch_idx): # 验证步骤代码 def configure_optimizers(self): # 配置优化器代码 # 实例化模型并添加ModelCheckpoint回调 model = LitModel() trainer = pl.Trainer(callbacks=[checkpoint_callback]) trainer.fit(model) ``` 在上面的代码中,我们创建了一个ModelCheckpoint实例,并将其作为回调添加到了Trainer中。在训练过程中,根据监控的指标(例如验证集损失),ModelCheckpoint会自动选择并保存表现最佳的模型
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值