springboot 实现机器学习_PytorchLightning:优雅的实现机器学习research idea

Pytorch Lightning是一个基于Pytorch的框架,旨在简化机器学习代码的实现,专注于研究想法而不是工程细节。它提供了GPU和TPU支持、参数量显示、进度条、学习率查找和自动调整batch大小等功能,提升训练效率。此外,Lightning还具有自动保存checkpoints和日志的功能。使用Lightning可以在SpringBoot项目中优雅地实施机器学习。
摘要由CSDN通过智能技术生成

cc08e8d794d4d10c8b41b8119a79e59d.png

简介

Pytorch Lightning是基于Pytorch的一个框架,能够很好地组织用Pytorch实现的机器学习代码。它的强大之处在于将科研想法的实现和繁琐的工程代码解耦,让研究者只用关注于科研想法的实现,而工程代码则由Lightning模块帮助你完成,代码变得简洁而优雅。极度适合研究人员。

和学习一门语言一样,要学习一个框架,它背后的设计哲学很重要。

Pytorch Lightning的设计哲学可以概括为以下四点:

  1. 保证代码结构拥有最大的灵活性;
  2. 将不必要的样板代码抽象出来,但是同时保持其在必要时可以访问到;
  3. 系统应该是独立的;
  4. 深度学习的代码可以解耦成四部分,各司其职。首先是数据加载部分,可以使用pytorch提供的dataloader或者使用pytorch lightning的LightningDataModule; 其次是模型部分,这里是research idea主要体现的地方,也就是LightningModule;然后是训练部分,主要工程代码,由pytorch lightning的Trainer完成,不需要自己实现;最后相对次要的research相关代码,比如说logging等,在CallBacks中实现

如何安装

通过PyPI安装

pip install pytorch-lightning

通过conda安装

conda install pytorch-lightning -c conda-forge

Lighting训练的人性化设计

文章末有示例代码,在运行后,可以看到以下界面。

2ac269f10c2de4a313b30085cfd24532.png

有几点我个人觉得还是很惊艳的,可以帮助提高日常工作效率,

  1. 本机GPU以及TPU检测,lightning是支持GPU以及TPU训练的,并且可以很方便的指定具体用哪块GPU来训练
  2. 参数量显示,在网络设计好之后我们可能需要花费时间去计算网络的参数量,lightning可以直接在代码运行后显示出来
  3. 进度条,以及每个epoch,iteration的耗时显示,可以清晰的帮助我们看到时间消耗。
  4. 合理的代码改进建议,比如说在数据加载中,lightning检测到我的机器有4核,但是在数据中我们只使用了一个核,可能会是训练速度的瓶颈,lightning建议通过设置num_workers的值,来提高训练速度。可以帮助我们快速的写出高质量的代码
  5. 完善的模型以及日志机制,在代码运行完成后,lightning会自动帮我们保存相应的checkpoints,对应的日志文件以及hyperparameter

Lightning效率之美

lightning能够将一些繁琐的超参设计流程自动化,帮助我们提高工作效率。

1. Learning Rate Finder

在深度学习训练中,合理的学习率设置可以帮助网络取得更好的表现以及更快的收敛,lighting提供auto_lr_find的设置,在初始训练的时候帮我们找到一个比较合理的学习率,而不是单纯靠猜。

bba0d6e1ef2194b7cfe1b739e2ad59c3.png

2. Auto Scale Batch Size

在机器学习任务中,一般来说,大的batch size可以提供更稳定的梯度,但是人为调节batch size可能会比较繁琐。Pytorch Lightning可以自动根据机器内存的大小缩放batch size,以达到最适合当前任务的一个值。

0b54ed880be1aa7cd2583d413a1c9283.png

3. GPU 16-precision Training

更多自动化操作可以参考这个网址,PyTorch Lightning Documentation

参考资料

https://github.com/PyTorchLightning/pytorch-lightning

https://pytorch-lightning.readthedocs.io/en/latest/introduction_guide.html

样例代码

import os
import torch
import pytorch_lightning as pl
import torch.nn.functional as F

from torch import nn
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split

class LitAutoEncoder(pl.LightningModule):

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128), 
            nn.ReLU(), 
            nn.Linear(128, 3))
        self.decoder = nn.Sequential(
            nn.Linear(3, 128), 
            nn.ReLU(), 
            nn.Linear(128, 28 * 28))

    def forward(self, x):
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

dataset = MNIST(os.getcwd(), download=True, transform=transforms.ToTensor())

train, val = random_split(dataset, [55000, 5000])

autoencoder = LitAutoEncoder()
trainer = pl.Trainer()
trainer.fit(autoencoder, DataLoader(train), DataLoader(val))
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值