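What is PyTorch Lightning
The main advantages of PyTorch Lightning (PL) include:
- Automated training: PyTorch Lightning runs the training loop for you. It iterates over the batches from your data loaders, calls your forward pass and loss computation, and handles backpropagation and optimizer steps. A complete deep learning project can fit in roughly 100 lines of code.
- Distributed training: PyTorch Lightning can train on multiple GPUs or multiple machines to speed up training, and it needs very little configuration to do so.
- Reproducibility: PyTorch Lightning provides APIs for fixing the random seed and the training environment, which helps make each run reproducible.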
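In short, PyTorch Lightning is a powerful and flexible framework that helps users train and develop deep learning models more efficiently. Its utilities make training code easier to organize and manage, which saves time in day-to-day work.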
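Common features
The basic workflow of a PL deep learning project:
- Define a PyTorch Lightning Module (a LightningModule subclass)
- Define a Trainer
- Call the Trainer to train and validate the module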
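Automatic training logs
One of PL's conveniences is that it records training logs for your PyTorch model, including training error and test error. By default, PL uses TensorBoard as its logger.
To view the logs, run the following command in a terminal:
tensorboard --logdir=lightning_logs/
The on_epoch argument of self.log controls whether the metric is accumulated over each epoch and logged once per epoch.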
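# num_sanity_val_steps=0 skips the validation sanity check that runs before training
# (original note: num_sanity_val_steps=0 because of va_spo_list)
trainer = pl.Trainer(max_epochs=MAX_EPOCHS,
                     num_sanity_val_steps=0)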
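Evaluating a model in one line with torchmetrics
torchmetrics is a metrics library for PyTorch. It provides common evaluation metrics for measuring model performance on tasks such as classification, regression, segmentation and generation.
torchmetrics supports common metrics such as accuracy, precision, recall, F1 score, AUC, mean absolute error and root mean squared error. It also offers more advanced ones, such as the multiclass confusion matrix, the Jaccard index (IoU) and the Dice coefficient.
torchmetrics aims to give a concise, flexible and extensible way to compute and record model performance metrics. It integrates tightly with PyTorch and slots into training and validation loops, as the code at the end of this article shows.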
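Loading a trained checkpoint
# Load the model from a saved checkpoint
# (SignLanguageClassifier and test_dataloader are assumed to be defined elsewhere)
CHECKPOINT_PATH = 'lightning_logs/version_9/checkpoints/epoch=59-step=120000.ckpt'
TEMP_VIDEO_PATH = 'tmp_video'
MODEL_TYPE = 'slowfast'
# strict=False tolerates keys that differ between the checkpoint and the model
classifier = SignLanguageClassifier.load_from_checkpoint(CHECKPOINT_PATH, strict=False, model_type=MODEL_TYPE)
classifier.model_type = MODEL_TYPE
trainer = pl.Trainer()
# Run inference on the test set
trainer.test(classifier, test_dataloader)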
Choosing the device: CUDA or CPU?
trainer = pl.Trainer(max_epochs=MAX_EPOCH,
                     devices='auto', accelerator='auto',  # to run on the CPU only, set accelerator='cpu'
                     logger=tensorboard_logger)
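Deep learning project template
Below is image-classification code I wrote with PL for a real project. It runs 5-fold cross-validation over the whole dataset and reports the average accuracy and the confusion matrix. The code shows PL's logic and overall flow quite directly.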
"""
A Python script to train ResNet-18 using PyTorch Lightning. The dataset includes 5 categories.
Report the classification accuracy and confusion matrix with torchmetrics.
Use 5-fold stratified sampling.
Report the final average classification accuracies at the end of the program.
"""
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import models, transforms
import torchmetrics
from sklearn.model_selection import StratifiedKFold
import seaborn
import matplotlib.pyplot as plt

MAX_EPOCH = 100


class Classifier(pl.LightningModule):
    def __init__(self, num_classes: int, model_type: str = 'resnet18'):
        super().__init__()
        self.model_type = model_type
        if model_type == 'resnet18':
            # torchvision >= 0.13 API; older versions use pretrained=True instead
            self.model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
            # Replace the final layer with a small classification head
            self.model.fc = torch.nn.Sequential(
                torch.nn.Linear(self.model.fc.in_features, 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, 64),
                torch.nn.ReLU(),
                torch.nn.Linear(64, num_classes)
            )
        elif model_type == 'mlp':
            self.model = torch.nn.Sequential(
                torch.nn.Linear(40, 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, 64),
                torch.nn.ReLU(),
                torch.nn.Linear(64, num_classes)
            )
        else:
            raise ValueError(f'Invalid model_type: {model_type}')
        self.accuracy = torchmetrics.classification.MulticlassAccuracy(num_classes)
        self.conf_mat = torchmetrics.classification.MulticlassConfusionMatrix(num_classes, normalize='true')

    def forward(self, x):
        if self.model_type == 'resnet18':
            x = x.view(x.size(0), 1, -1, 1)  # Reshape 1D data into a single-channel "image"
            x = torch.repeat_interleave(x, repeats=3, dim=1)  # Copy to 3 channels for ResNet
        return self.model(x.float())

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y.long())
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        # Update the metric states first, then log
        self.conf_mat.update(y_hat, y.long())
        self.accuracy.update(y_hat, y.long())
        self.log('val_accuracy', self.accuracy, on_epoch=True, prog_bar=True)
        self.log('val_loss', F.cross_entropy(y_hat, y.long()), on_step=True, prog_bar=True)

    def on_validation_end(self):
        conf_matrix = self.conf_mat.compute()
        print(conf_matrix)
        plt.figure()
        seaborn.heatmap(conf_matrix.cpu(), annot=True)
        # fold_id is the module-level counter updated in the cross-validation loop
        plt.savefig(f'conf_mat_{fold_id}.png')
        plt.close()  # Free the figure so they do not pile up across epochs
        accuracy_computed = self.accuracy.compute()
        print(f'Fold Accuracy={accuracy_computed}')
        self.conf_mat.reset()  # Start the next validation run from an empty matrix

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.00001)


# Load data and labels from .npy file
data_and_labels = np.load('data/data_and_labels.npy', allow_pickle=True).item()
X = data_and_labels['X']
y = data_and_labels['y']

# Prepare 5-fold stratified sampling
skf = StratifiedKFold(n_splits=5, shuffle=True)

# Initialize list for storing classification accuracies
accuracies = []
fold_id = 0

# Perform 5-fold stratified sampling
for train_index, val_index in skf.split(X, y):
    X_train, X_val = X[train_index], X[val_index]
    y_train, y_val = y[train_index], y[val_index]

    # Create TensorDatasets
    train_data = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))
    val_data = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))

    # Create DataLoaders
    train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_data, batch_size=64)

    # Model
    model = Classifier(num_classes=5)

    # Training
    tensorboard_logger = TensorBoardLogger(save_dir='.', version=fold_id)
    trainer = pl.Trainer(max_epochs=MAX_EPOCH, devices='auto', accelerator='auto', logger=tensorboard_logger)
    trainer.fit(model, train_loader, val_loader)

    # Keep the last logged validation accuracy of this fold
    accuracies.append(trainer.callback_metrics['val_accuracy'].item())
    fold_id += 1

# Report the average accuracy over all folds
print(f'Average accuracy over {len(accuracies)} folds: {np.mean(accuracies):.4f}')
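To see what the stratified split in the script does, here is a small standalone sketch on toy labels. The label array, the zero-filled feature matrix and `random_state=0` are assumptions made for the example; the script above loads real data instead.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

# Toy labels: 5 classes with 10 samples each (hypothetical data).
y = np.repeat(np.arange(5), 10)
X = np.zeros((len(y), 40), dtype=np.float32)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# Count how many samples of each class land in every validation fold.
fold_class_counts = []
for train_index, val_index in skf.split(X, y):
    fold_class_counts.append(np.bincount(y[val_index], minlength=5).tolist())

for fold, counts in enumerate(fold_class_counts):
    print(f"fold {fold}: validation samples per class = {counts}")
```

Each validation fold receives the same number of samples from every class, which keeps the per-fold accuracy and confusion matrix comparable. Passing a fixed `random_state` also makes the split itself repeatable between runs.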