“简约版”Pytorch —— Pytorch-Lightning详解

PyTorch Lightning是一个强大的库,旨在抽象PyTorch中的训练细节,使其适合大规模AI研究。它提供了自动调整Batch Size、学习率寻找、数据重新加载等功能,还支持GPU和分布式训练。此外,它包含各种训练循环选项、进度条、回调系统和混合精度训练。通过使用Lightning,可以简化复杂的模型训练和调试过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

介绍

来源:http://blog.itpub.net/31555081/viewspace-2698296/

PyTorch很容易使用,可以用来构建复杂的AI模型。但是一旦研究变得复杂,并且将诸如多GPU训练,16位精度和TPU训练之类的东西混在一起,用户很可能会写出有bug的代码。
PyTorch Lightning完全解决了这个问题。Lightning会构建您的PyTorch代码,以便抽象出训练的详细信息。这使得AI研究可扩展并且可以快速迭代。
PyTorch Lightning是NYUFAIR(Facebook AI research)为从事AI研究的专业研究人员和博士生所创建的。

安装

conda activate your-env
pip install pytorch-lightning

实用功能 (Trainer参数详解)

来源:Lightning官方教程官方文档

自动获取Batch Size - Automatic Batch Size Finder

auto_scale_batch_size

Batch Size一般会对模型的训练结果有影响,一般越大的batch size模型训练的结果会越好。有时候我们不知道自己的模型在当前的机器上最多能用多大的batch size,这时候通过Lightning Trainer的这个flag就可以帮助我们找到最大的batch size。

model = ...
# 设置为True,Trainer就会依次尝试用2的幂次方的batch size,直到超出内存
trainer = pl.Trainer(auto_scale_batch_size=True)
trainer.fit(model)

# 设置为'binsearch',Trainer会用Binary Search的方式帮你找到最大的Batch Size
trainer = pl.Trainer(auto_scale_batch_size='binsearch')
trainer.tune(model)

# 注意:如果要用这个功能,在Module里面的__init__()函数中要有:
self.batch_size = batch_size
# 或者在__init__()里面调用:
self.save_hyperparameters()

自动获取初始学习率 - Automatic Learning Rate Finder

auto_lr_find

学习率learning rate是很重要的一个超参,选取一个合适的初始学习率也是很重要的,Lightning提供了这个有用的flag。
(15.11.2020)目前只支持单优化器(Optimizer),预计在未来的几个月内会支持多优化器

import pytorch_lightning as pl

model = ...
# 可以直接设置为True,Trainer会自动用不同的学习率运行model,然后画出loss和学习率的曲线,帮你找到最合适的学习率
trainer = pl.Trainer(auto_lr_find=True)
trainer.tune(model)
print(model.learning_rate)

# 有时候我们会在model中给学习率起其他的名字,比如:
self.my_learning_rate = lr
# 这个时候我们可以用变量名直接设置auto_lr_find:
trainer = pl.Trainer(auto_lr_find='my_learing_rate')
# 开始寻找合适的学习率
lr_finder = trainer.tuner.lr_find(model)
# 展示loss和学习率的曲线
fig = lr_finder.plot(suggest=True)
fig.show()
# 设置为推荐的学习率
model.hyparams.learning_rate = lr_finder.suggestion()
# 开始训练
model.fit(model, train_loader, val_loader)

重新加载数据 - Reload DataLoaders Every Epoch

reload_dataloaders_every_epoch

一般数据只会在一开始加载一次,即在epochs的循环前面加载一次,然后每个循环都会shuffle之类的(如果你设置shuffle为True的话)。有时候我们的数据在训练过程中是会改变的,这个时候我们就需要在每个epoch都要再加载一次数据,Lightning就提供了这样一个flag,将其设置为True即可。

# 相当于:
# if False (default)
train_loader = model.train_dataloader()
for epoch in epochs:
    for batch in train_loader:
        ...

# if True
for epoch in epochs:
    train_loader = model.train_dataloader()
    for batch in train_loader:

回调函数 - Callbacks

callbacks

回调函数 (Callbacks) 在机器学习中也是很重要的工具,一般可以用来进行模型的断点断续,模型权重的存储,提早停止 (Early stop),动态调整训练参数以及tensorboard之类的训练可视化等等。Lightning也支持非常灵活的Callbacks,只需要把Callbacks放进flag:callbacks中即可。Lightning提供了一些built-in的callbacks,同样也支持自定义callbacks类,所以非常灵活。
官方的callbacks说明

# 自定义Callback类
from pytorch_lightning.callbacks import Callback

class MyPrintingCallback(Callback):

    def on_init_start(self, trainer):
        print('Starting to init trainer!')

    def on_init_end(self, trainer):
        print('trainer is init now')

    def on_train_end(self, trainer, pl_module):
        print('do something when training ends')

trainer = Trainer(callbacks=[MyPrintingCallback()])

# 使用built-in的Callbacks
from pytorch_lightning.callbacks import EarlyStopping

# 可以直接使用默认的Callbacks
trainer = pl.Trainer(callbacks
评论 12
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值