训练
训练部分已在《入门篇》介绍。
验证集和测试集中评估模型
通常将数据集分为三部分,train/val/test,val集在训练时评估模型的泛化性,选择其中表现最好的checkpoint。test集只在模型训练完成后使用,用于评估模型的真实性能。
添加test流程
划分数据集
以下代码使用torchvision包内实现的MNIST。如果使用自定义的数据集,先用pytorch实现Dataset子类,再继承pl.LightningDataModule类,实现相应接口。
import torch.utils.data as data
from torchvision import datasets
import torchvision.transforms as transforms
# Load data sets
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)
实现test_step()接口
在trainer.test()阶段会自动调用test_step方法,根据需要内部可以增加保存图片、评估模型等功能。
class LitAutoEncoder(pl.LightningModule):
def training_step(self, batch, batch_idx):
...
def test_step(self, batch, batch_idx):
# this is the test loop
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
test_loss = F.mse_loss(x_hat, x)
self.log("test_loss", test_loss)
测试
模型训练完成后,即可调用test()方法进入测试流程
from torch.utils.data import DataLoader
# initialize the Trainer
trainer = Trainer()
# 训练模型
trainer.fit(model, data)
# 训练完成后测试
trainer.test(model, dataloaders=DataLoader(test_set))
验证阶段validation的流程
与test 流程类似,实现validation_step()接口,可以配合on_validation_epoch_end()方法在计算所有样例后评估模型。
class LitAutoEncoder(pl.LightningModule):
def training_step(self, batch, batch_idx):
...
def validation_step(self, batch, batch_idx):
# this is the validation loop
x, y = batch
x = x.view(x.size(0), -1)
z = self.encoder(x)
x_hat = self.decoder(z)
val_loss = F.mse_loss(x_hat, x)
self.log("val_loss", val_loss)
self.metric.update(x_hat, y)# metric是任务相关的评价方法,比如更新混淆矩阵
def on_validation_epoch_step(self, batch, batch_idx):
# 从混淆矩阵中计算tp,fp, tn, fn, acc, F1等指标
score = self.metric.get_scores()
# 记录,横坐标为epoch
self.log('val/F1', score['F1'], logger=True, on_epoch=True)
预测predict流程
实现predict_step方法,然后调用trainer.predict()
其它HOOK见LightningModule,了解LightningModule的接口基本就会用pl了。