PyTorch深度学习必用库之PyTorch Lightning新手教程(附100行的完整深度学习项目代码模板)

什么是PyTorch Lightning

PyTorch Lightning(PL)的主要优势包括:

  • 训练自动化:PyTorch Lightning可以帮助开发者处理训练循环,包括数据加载、批次迭代、前向传播、损失计算和反向传播等。100行左右的代码就可以写出完整的深度学习项目。
  • 分布式训练支持:PyTorch Lightning支持分布式训练,可以在多个GPU或多台机器上进行训练,从而加快训练速度,而且配置特别简单。
  • 可复现性:PyTorch Lightning提供的API方便用户使用固定的随机种子和训练环境,确保每次运行的结果是可复现的。

总之,PyTorch Lightning是一个强大而灵活的框架,可以帮助用户更高效地进行深度学习模型的训练和开发。它提供了许多易用的功能和工具,使用户可以更好地管理和组织训练代码,提高工作效率。

常用功能

pl深度学习项目的基本思路:

  1. 定义PyTorch Lightning Module
  2. 定义Trainer
  3. 调用Trainer训练并检验深度学习模块

自动储存训练日志

PL的便捷功能其中之一是在PyTorch中记录包括训练误差、测试误差的训练日志。PL默认使用Tensorboard来记录日志。

要查看日志,可以在终端中运行以下命令:

tensorboard --logdir=lightning_logs/

可以使用on_epoch参数来确定是否记录每个epoch的累积指标。

trainer = pl.Trainer(max_epochs=MAX_EPOCHS, 
num_sanity_val_steps=0,  ) # num_sanity_val_steps=0 because of va_spo_list

使用torchmetrics一行代码评估模型

torchmetrics是一个用于PyTorch深度学习库的指标计算和评估工具包。它提供了一系列常用的评估指标,用于衡量模型在不同任务上的性能,包括分类、回归、分割和生成等。

torchmetrics支持各种常见的评估指标,如准确率、精确度、召回率、F1分数、AUC、平均绝对误差、均方根误差等。它还提供了一些高级指标,如多类别混淆矩阵、Jaccard系数、Dice系数和IoU等。

torchmetrics的设计目标是提供一种简洁、灵活和可扩展的方式来计算和记录模型性能指标。它与PyTorch框架紧密集成,可以无缝地与PyTorch的训练和验证流程结合使用。这一点从本文文末提供的代码可以感受得到。

加载训练好的checkpoint

# load the model  
CHECKPOINT_PATH = 'lightning_logs/version_9/checkpoints/epoch=59-step=120000.ckpt'  
TEMP_VIDEO_PATH = 'tmp_video'  
MODEL_TYPE = 'slowfast'  
classifier = SignLanguageClassifier.load_from_checkpoint(CHECKPOINT_PATH, strict=False, model_type=MODEL_TYPE)  
classifier.model_type = MODEL_TYPE  
trainer = pl.Trainer()

# make inference 
trainer.test(classifier, test_dataloader)

定义模型运行的设备:CUDA or CPU?

trainer = pl.Trainer(max_epochs=MAX_EPOCH, 
devices='auto', accelerator='auto', 	# 如果只用CPU,把'auto'改成'cpu'就行了
logger=tensorboard_logger)

深度学习实战项目模板

下面是我在实战中用PL写的图片分类代码,对整个数据集进行5折交叉验证后汇报平均准确率和混淆矩阵。这个代码可以很直观地体现出PL的逻辑和整体流程。

"""  
a Python script to train ResNet-18 using PyTorch Lightning. The dataset includes 5 categories.  
Report the classification accuracy and confusion matrix with torch-metrics.  
  
Use 5-fold stratified sampling.  
Report the final average classification accuracies at the end of the program.  
"""  
  
  
import numpy as np  
import pytorch_lightning as pl  
from pytorch_lightning.loggers import TensorBoardLogger  
import torch  
from torch.nn import functional as F  
from torch.utils.data import DataLoader, TensorDataset  
from torchvision import models, transforms  
import torchmetrics  
from sklearn.model_selection import StratifiedKFold  
import seaborn  
import matplotlib.pyplot as plt  
  
  
MAX_EPOCH = 100  
  
  
class Classifier(pl.LightningModule):  
    def __init__(self, num_classes: int, model_type: str = 'resnet18'):  
        super().__init__()  
        self.model_type = model_type  
        if model_type == 'resnet18':  
            self.model = models.resnet18(pretrained=True)  
            self.model.fc = torch.nn.Sequential(  
                    torch.nn.Linear(self.model.fc.in_features, 128),  
                    torch.nn.ReLU(),  
                    torch.nn.Linear(128, 64),  
                    torch.nn.ReLU(),  
                    torch.nn.Linear(64, num_classes)  
                    )  
        elif model_type == 'mlp':  
            self.model = torch.nn.Sequential(  
                torch.nn.Linear(40, 128),  
                torch.nn.ReLU(),  
                torch.nn.Linear(128, 64),  
                torch.nn.ReLU(),  
                torch.nn.Linear(64, num_classes)  
            )  
        else:  
            raise ValueError(f'Invalid model_type: {model_type}')  
        self.accuracy = torchmetrics.classification.MulticlassAccuracy(num_classes)  
        self.conf_mat = torchmetrics.classification.MulticlassConfusionMatrix(num_classes, normalize='true')  
  
    def forward(self, x):  
        if self.model_type == 'resnet18':  
            x = x.view(x.size(0), 1, -1, 1)    # Reshape 1D data into a single-channel "image"  
            x = torch.repeat_interleave(x, repeats=3, dim=1)  
        return self.model(x.float())  
  
    def training_step(self, batch, batch_idx):  
        x, y = batch  
        y_hat = self(x)  
        loss = F.cross_entropy(y_hat, y.long())  
        self.log('train_loss', loss, )  
        return loss  
  
    def validation_step(self, batch, batch_idx):  
        x, y = batch  
        y_hat = self(x)  
        self.log('val_accuracy', self.accuracy, on_epoch=True, prog_bar=True)  
        self.log('val_loss', F.cross_entropy(y_hat, y.long()), on_step=True, prog_bar=True)  
        self.conf_mat.update(y_hat, y)  
        self.accuracy.update(y_hat, y)  
  
    def on_validation_end(self):  
        conf_matrix = self.conf_mat.compute()  
        print(conf_matrix)  
        plt.figure()  
        seaborn.heatmap(conf_matrix.cpu(), annot=True)  
        plt.savefig(f'conf_mat_{fold_id}.png')  
        accuracy_computed = self.accuracy.compute()  
        print(f'Fold Accuracy={accuracy_computed}')  
  
    def configure_optimizers(self):  
        return torch.optim.Adam(self.parameters(), lr=0.00001)  
  
# Load data and labels from .npy file  
data_and_labels = np.load('data/data_and_labels.npy', allow_pickle=True).item()  
X = data_and_labels['X']  
y = data_and_labels['y']  
  
# Prepare 5-fold stratified sampling  
skf = StratifiedKFold(n_splits=5, shuffle=True)  
  
# Initialize list for storing classification accuracies  
accuracies = []  
fold_id = 0  
# Perform 5-fold stratified sampling  
for train_index, val_index in skf.split(X, y):  
    X_train, X_val = X[train_index], X[val_index]  
    y_train, y_val = y[train_index], y[val_index]  
  
    # Create TensorDatasets  
    train_data = TensorDataset(torch.from_numpy(X_train), torch.from_numpy(y_train))  
    val_data = TensorDataset(torch.from_numpy(X_val), torch.from_numpy(y_val))  
  
    # Create DataLoaders  
    train_loader = DataLoader(train_data, batch_size=64, shuffle=True)  
    val_loader = DataLoader(val_data, batch_size=64)  
  
    # Model  
    model = Classifier(num_classes=5)  
  
    # Training  
    tensorboard_logger = TensorBoardLogger(save_dir='.', version=fold_id)  
    trainer = pl.Trainer(max_epochs=MAX_EPOCH, devices='auto', accelerator='auto', logger=tensorboard_logger)  
    trainer.fit(model, train_loader, val_loader)  
  
    fold_id += 1
  • 10
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch Lightning是一种轻量级的高级PyTorch封装,它使得训练神经网络更加容易、更加模块化。它提供了许多常用的功能,例如自动分布式训练、自动检查点、自动日志记录等等。下面是一个PyTorch Lightning的学习指南: 1. 先学习PyTorch基础知识:在学习PyTorch Lightning之前,您需要先学习PyTorch的基础知识,例如如何构建神经网络、如何训练模型等等。 2. 安装PyTorch Lightning:在安装PyTorch Lightning之前,您需要先安装PyTorch。然后可以通过pip安装PyTorch Lightning。 3. 了解PyTorch Lightning的核心概念:PyTorch Lightning的核心概念是“LightningModule”、“Trainer”和“DataModule”。LightningModule是您定义神经网络的地方,Trainer是您定义训练过程的地方,DataModule是您定义数据集的地方。 4. 编写您的第一个PyTorch Lightning程序:您可以从一个简单的例子开始,例如MNIST手写数字识别。在这个例子中,您可以定义一个LightningModule来构建神经网络,定义一个DataModule来加载数据集,然后定义一个Trainer来训练模型。 5. 学习如何自动分布式训练:PyTorch Lightning可以自动进分布式训练,这意味着您可以在多个GPU或多台计算机上训练模型。您只需要在Trainer中设置一些参数即可。 6. 学习如何自动检查点和日志记录:PyTorch Lightning可以自动保存检查点和记录日志,这使得您可以在训练过程中随时恢复模型并查看训练指标。 7. 学习如何使用PyTorch Lightning扩展您的研究:PyTorch Lightning提供了许多扩展功能,例如自动优化器、自动批量大小调整、自动对抗性训练等等。您可以使用这些功能来扩展您的研究。 总之,PyTorch Lightning是一个非常强大的工具,可以使训练神经网络更加容易和高效。如果您想提高您的PyTorch技能并加快训练过程,请考虑学习PyTorch Lightning

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值