pl.Trainer 是 PyTorch Lightning 的核心类之一,用于简化深度学习训练的流程。gpus 参数用于指定训练过程中使用的 GPU。以下是对 gpus 参数的详细说明:
1. 参数功能
gpus 参数决定 PyTorch Lightning 是否以及如何使用 GPU 进行训练:
- gpus=None: 不使用 GPU,训练将在 CPU 上进行。
- gpus=0: 等价于 gpus=None,不使用 GPU。
- gpus=1: 使用一个 GPU。
- gpus=[0, 1]: 指定使用 GPU 0 和 1。
- gpus=-1: 使用系统中所有可用的 GPU。
- gpus="0,1": 使用 GPU 0 和 1,支持字符串形式。
- gpus=2: 使用两个 GPU,自动选择可用的设备。
2. 多 GPU 训练
如果指定了多个 GPU,PyTorch Lightning 会自动选择合适的分布式训练策略(如 DataParallel 或 DDP)。
- DDP (DistributedDataParallel)
: 推荐的多 GPU 训练策略,具有更高的性能。可以通过设置 strategy='ddp' 启用:
trainer = pl.Trainer(gpus=2, strategy='ddp')
- DataParallel (DP)
: 较旧的多 GPU 方法,效率较低,但易于配置:
trainer = pl.Trainer(gpus=2, strategy='dp')
多 GPU 设置: 使用多个 GPU 时,确保正确配置
strategy 以避免性能问题。
3、auto_select_gpus
有时候,我们不知道那些GPU是被占用的,也就没办法指定GPU,Lightning为此提供了flag,它可以替我们检测可以使用的GPU个数以及序号。
不自动选择GPU (直接选择系统中的前两个GPU, 如果它们被占用则会失败)
trainer = Trainer(gpus=2, auto_select_gpus=False)
自动从系统中选择两个可用的GPU
trainer = Trainer(gpus=2, auto_select_gpus=True)
指定所有的GPU,不管它们是否被占用
Trainer(gpus=-1, auto_select_gpus=False)
指定所有可用的GPU (如果只有一个GPU可用,则只会使用一个GPU)
Trainer(gpus=-1, auto_select_gpus=True)
4、log_gpu_memory
如果我们想要监测GPU内存的使用状况,Lightning也提供了相应的flag。使用的话可能会使训练变慢,因为它使用的是nvidia-smi的输出。
默认不监测GPU内存
trainer = Trainer(log_gpu_memory=None)
监测主节点上的所有GPU
trainer = Trainer(log_gpu_memory='all')
只记录主节点上的最小以及最大GPU内存使用
trainer = Trainer(log_gpu_memory='min_max')
5、benchmark
如果你模型的输入大小保持不变,那么可以设置cudnn.benchmark为True,这样可以加速训练,如果输入大小不固定,那么反而会减慢训练
6. 完整示例
t.py
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
# 定义一个简单的模型
class SimpleModel(pl.LightningModule):
def __init__(self):
super().__init__()
self.layer = nn.Linear(12800, 6400)
self.loss_fn = nn.MSELoss()
def forward(self, x):
return self.layer(x)
def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = self.loss_fn(y_hat, y)
return loss
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=1e-3)
# 生成一些数据
x = torch.randn(1024, 12800)
y = torch.randn(1024, 6400)
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=320)
# 创建 Trainer 并启用混合精度 只调用1显卡
trainer = pl.Trainer(
gpus=[2],
precision=32, # 启用混合精度
max_epochs=2 # 最大训练轮数
)
# 训练模型
model = SimpleModel()
trainer.fit(model, dataloader)
###只调用显卡1
trainer = pl.Trainer(
gpus=1,
precision = 16,
max_epochs=2
)
CUDA_VISIBLE_DEVICES=1 python t.py
###只调用显卡2
trainer = pl.Trainer(
gpus=1,
precision = 16,
max_epochs=2
)
CUDA_VISIBLE_DEVICES=2 python t.py
# 使用accelerator="ddp"同时调用0,2显卡
trainer = pl.Trainer(
gpus=[0,2], # gpus="0,2"也可以
accelerator="ddp",
precision=32, # 启用混合精度
max_epochs=5 # 最大训练轮数
)
python t.py
# 使用accelerator="ddp"调用1显卡
trainer = pl.Trainer(
gpus=[1],
accelerator="ddp",
precision=32, # 启用混合精度
max_epochs=5 # 最大训练轮数
)
python t.py
# 只调用显卡2
trainer = pl.Trainer(
gpus=[2],
precision=32, # 启用混合精度
max_epochs=2 # 最大训练轮数
)
python t.py
###使用多卡必须accelerator="ddp",否则报错
trainer = pl.Trainer(
gpus=[0,2],
precision=32, # 启用混合精度
max_epochs=2 # 最大训练轮数
)
python t.py
Traceback (most recent call last):
File "t4.py", line 56, in <module>
trainer.fit(model, dataloader)
File "/home/fyf/anaconda3/envs/ldm_fyf/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in fit
self._call_and_handle_interrupt(
File "/home/fyf/anaconda3/envs/ldm_fyf/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 682, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/home/fyf/anaconda3/envs/ldm_fyf/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 770, in _fit_impl
self._run(model, ckpt_path=ckpt_path)
File "/home/fyf/anaconda3/envs/ldm_fyf/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1193, in _run
self._dispatch()
File "/home/fyf/anaconda3/envs/ldm_fyf/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1272, in _dispatch
self.training_type_plugin.start_training(self)
File "/home/fyf/anaconda3/envs/ldm_fyf/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp_spawn.py", line 173, in start_training
self.spawn(self.new_process, trainer, self.mp_queue, return_result=False)
File "/home/fyf/anaconda3/envs/ldm_fyf/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp_spawn.py", line 201, in spawn
mp.spawn(self._wrapped_function, args=(function, args, kwargs, return_queue), nprocs=self.num_processes)
File "/home/fyf/anaconda3/envs/ldm_fyf/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
File "/home/fyf/anaconda3/envs/ldm_fyf/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
while not context.join():
File "/home/fyf/anaconda3/envs/ldm_fyf/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 139, in join
raise ProcessExitedException(
torch.multiprocessing.spawn.ProcessExitedException: process 0 terminated with exit code 1