PyTorch Lightning教程二:验证、测试、checkpoint、早停策略

介绍:上一期介绍了如何利用PyTorch Lightning搭建并训练一个模型(仅使用训练集),为了保证模型可以泛化到未见过的数据上,数据集通常被分为训练和测试两个集合,测试集与训练集相互独立,用以测试模型的泛化能力。本期通过增加验证和测试集来达到该目的,同时,还引入checkpoint和早停策略,以得到模型最佳权重。

相关链接:https://lightning.ai/docs/pytorch/stable/levels/basic_level_2.html

训练集、验证集、测试集的使用

1.添加依赖,获取训练集和测试集

添加相应的依赖,同时使用MNIST数据集,获取训练和测试集

import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# 加载数据(测试集,train=False)
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)
2.实现并调用test_step

在定义LightningModule中,实现test_step方法;在外部,调用test方法

class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def test_step(self, batch, batch_idx): # 测试,该方法与training_step相似
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

# 初始化Trainer
trainer = Trainer()

# 执行test方法
trainer.test(model, dataloaders=DataLoader(test_set))
3.实现并调用验证集

通常使用torch.utils.data中的方法,将训练集中的一部分数据化为验证集

# 训练集中的20%数据划为验证集
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# 拆分,使用data.random_split方法
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)

与测试集一样,需要在定义LightningModule中,实现validation_step方法;在外部,调用fit方法

class LitAutoEncoder(pl.LightningModule):
    def training_step(self, batch, batch_idx):
        ...

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)
    
    def test_step(self, batch, batch_idx):
        ...
# 调用torch.utils.data中的DataLoader对训练和测试集进行封装
train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)

# 在fit方法中,引入valid_loader,即验证集
trainer = Trainer()
trainer.fit(model, train_loader, valid_loader)

checkpoint

checkpoint有两个作用,一是能得到每一次epoch后的模型权重,能得到最佳表现的权重;二是能够在中断或停止后,继续在当前checkpoint处,继续训练。在Lightning中的checkpoint,包含模型的整个内部状态,这与普通的PyTorch不同,即使在最复杂的分布式训练环境中,Lightning也可以保存恢复模型所需的一切。包含以下状态:

  • 16-bit scaling factor (若使用16精度训练)
  • Current epoch
  • Global step
  • LightningModule’s state_dict
  • State of all optimizers
  • State of all learning rate schedulers
  • State of all callbacks (for stateful callbacks)
  • State of datamodule (for stateful datamodules)
  • The hyperparameters (init arguments) with which the model was created
  • The hyperparameters (init arguments) with which the datamodule was created
  • State of Loops
保存与调用方法
# 保存方法,可自定义default_root_dir路径,若不设置路径,将会自动保存
trainer = Trainer(default_root_dir="some/path/")

# 调用方法
model = MyLightningModule.load_from_checkpoint("/path/to/checkpoint.ckpt")
model.eval()	# disable randomness, dropout, etc...
y_hat = model(x)

调用,还可以使用torch的方法

checkpoint = torch.load(checkpoint, map_location=lambda storage, loc: storage)
print(checkpoint["hyper_parameters"])
# {"learning_rate": the_value, "another_parameter": the_other_value}

也可以实现重现,例如模型LitModel(in_dim=32, out_dim=10)

# 使用 in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)
# 使用 in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)

Lightning和PyTorch完全兼容

checkpoint = torch.load(CKPT_PATH)
encoder_weights = checkpoint["encoder"]
decoder_weights = checkpoint["decoder"]

设置checkpoint不可见

trainer = Trainer(enable_checkpointing=False)

如果想全部重新恢复

model = LitModel()
trainer = Trainer()

自动恢复所有相关参数 model, epoch, step, LR schedulers, etc…

trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")

早停策略

EarlyStopping Callback

在Lightning中,早停回调步骤如下:

  • Import EarlyStopping callback. 载入EarlyStopping回调方法
  • Log the metric you want to monitor using log() method. 加载日志方法
  • Init the callback, and set monitor to the logged metric of your choice. 设置monitor
  • Set the mode based on the metric needs to be monitored. 设置mode
  • Pass the EarlyStopping callback to the Trainer callbacks flag. 调入EarlyStropping
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

class LitModel(LightningModule):
    def validation_step(self, batch, batch_idx):
        loss = ...
        self.log("val_loss", loss)

model = LitModel()
trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
trainer.fit(model)

# 也可以使用自定义的EarlyStopping策略
early_stop_callback = EarlyStopping(monitor="val_accuracy", min_delta=0.00, patience=3, verbose=False, mode="max")
trainer = Trainer(callbacks=[early_stop_callback])
# EarlyStopping的文档链接https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping
注意

完整代码

# coding:utf-8
import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import lightning as L

# --------------------------------
# Step 1: 定义模型
# --------------------------------
class LitAutoEncoder(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(28 * 28, 128), nn.ReLU(), nn.Linear(128, 3))
        self.decoder = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 28 * 28))

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log("train_loss", loss)
        return loss

    def test_step(self, batch, batch_idx):  # 测试,该方法与training_step相似
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

    def forward(self, x):
        # forward 定义了一次 预测/推理 行为
        embedding = self.encoder(x)
        return embedding
# --------------------------------
# Step 2: 加载数据+模型
# --------------------------------
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)

# 训练集中的20%数据划为验证集
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# 拆分,使用data.random_split方法
seed = torch.Generator().manual_seed(42)
train_set, valid_set = data.random_split(train_set, [train_set_size, valid_set_size], generator=seed)
train_loader = DataLoader(train_set)
valid_loader = DataLoader(valid_set)

autoencoder = LitAutoEncoder()
# --------------------------------
# Step 3: 训练+验证+测试
# --------------------------------
# 训练+验证
trainer = L.Trainer(default_root_dir="some/path/")	# 这里自定义需要保存的路径
trainer.fit(autoencoder, train_loader, valid_loader)

# 测试
trainer.test(autoencoder, dataloaders=DataLoader(test_set))
PyTorch Lightning是一种轻量级的高级PyTorch封装,它使得训练神经网络更加容易、更加模块化。它提供了许多常用的功能,例如自动分布式训练、自动检查点、自动日志记录等等。下面是一个PyTorch Lightning的学习指南: 1. 先学习PyTorch基础知识:在学习PyTorch Lightning之前,您需要先学习PyTorch的基础知识,例如如何构建神经网络、如何训练模型等等。 2. 安装PyTorch Lightning:在安装PyTorch Lightning之前,您需要先安装PyTorch。然后可以通过pip安装PyTorch Lightning。 3. 了解PyTorch Lightning的核心概念:PyTorch Lightning的核心概念是“LightningModule”、“Trainer”和“DataModule”。LightningModule是您定义神经网络的地方,Trainer是您定义训练过程的地方,DataModule是您定义数据集的地方。 4. 编写您的第一个PyTorch Lightning程序:您可以从一个简单的例子开始,例如MNIST手写数字识别。在这个例子中,您可以定义一个LightningModule来构建神经网络,定义一个DataModule来加载数据集,然后定义一个Trainer来训练模型。 5. 学习如何自动分布式训练:PyTorch Lightning可以自动进行分布式训练,这意味着您可以在多个GPU或多台计算机上训练模型。您只需要在Trainer中设置一些参数即可。 6. 学习如何自动检查点和日志记录:PyTorch Lightning可以自动保存检查点和记录日志,这使得您可以在训练过程中随时恢复模型并查看训练指标。 7. 学习如何使用PyTorch Lightning扩展您的研究:PyTorch Lightning提供了许多扩展功能,例如自动优化器、自动批量大小调整、自动对抗性训练等等。您可以使用这些功能来扩展您的研究。 总之,PyTorch Lightning是一个非常强大的工具,可以使训练神经网络更加容易和高效。如果您想提高您的PyTorch技能并加快训练过程,请考虑学习PyTorch Lightning
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小风_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值