-
Pytorch-Lightning-Learning
0.引言
-
PyTorch Lightning完全继承于Pytorch,Pytorch的所有东西都可以在PyTorch Lightning中使用,PyTorch Lightning的所有物品也可在Pytorch中使用。
-
PyTorch Lightning作为Pytorch的高级封装,内部包含着完整且可修改的训练逻辑。
-
PyTorch Lightning的硬件检测基于Pytorch,也可以使用Trainer修改。
-
PyTorch Lightning中数据类型自动变化,无需.cpu和.cuda。
1.用Pytorch的Dataset和DataLoader定义数据集
2. 用LightningModule定义模型并实现训练逻辑
- LightningModule:Pytorch Lightning的两大API之一,是torch.nn.Module的高级封装。
(1)定义模型
-
__init__()
:同torch.nn.Module中的__init__,用于构建模型。 -
forward(*args, **kwargs)
:同torch.nn.Module中的forward,通过__init__中的各个模块实现前向传播。
(2)训练模型
#训练模型
training_step(*args, **kwargs)
"""
训练一批数据并反向传播。参数如下:
- batch (Tensor | (Tensor, …) | [Tensor, …]) – 数据输入,一般为x, y = batch。
- batch_idx (int) – 批次索引。
- optimizer_idx (int) – 当使用多个优化器时,会使用本参数。
- hiddens (Tensor) – 当truncated_bptt_steps > 0时使用。
"""
#举个例子:
def training_step(self, batch, batch_idx): # 数据类型自动转换,模型自动调用.train()
x, y = batch
_y = self(x)
loss = criterion(_y, y) # 计算loss
return loss # 返回loss,更新网络
def training_step(self, batch, batch_idx, hiddens):
# hiddens是上一次截断反向传播的隐藏状态
out, hiddens = self.lstm(data, hiddens)
return {"loss": loss, "hiddens": hiddens}
#--------------------------------------
training_step_end(*args, **kwargs)
"""一批数据训练结束时的操作。一般用不着,分布式训练的时候会用上。参数如下:
- batch_parts_outputs – 当前批次的training_step()的返回值
"""
#举个例子:
def training_step(self, batch, batch_idx):
x, y = batch
_y = self(x)
return {"output": _y, "target": y}
def training_step_end(self, training_step_outputs): # 多GPU分布式训练,计算loss
gpu_0_output = training_step_outputs[0]["output"]
gpu_1_output = training_step_outputs[1]["output"]
gpu_0_target = training_step_outputs[0]["target"]
gpu_1_target = training_step_outputs[1]["target"]
# 对所有GPU的数据进行处理
loss = criterion([gpu_0_output, gpu_1_output], [gpu_0_target, gpu_1_target])
return loss
#--------------------------------------
training_epoch_end(outputs)
"""一轮数据训练结束时的操作。主要针对于本轮所有training_step的输出。参数如下:
- outputs (List[Any]) – training_step()的输出。
"""
#举个例子:
def training_epoch_end(self, outs): # 计算本轮的loss和acc
loss = 0.
for out in outs: # outs按照训练顺序排序
loss += out["loss"].cpu().detach().item()
loss /= len(outs)
acc = self.train_metric.compute()
self.history["loss"].append(loss)
self.history["acc"].append(acc)
三个核心组件:
- 模型
- 优化器
- Train/Val/Test步骤
数据流伪代码:
outs = []
for batch in data:
out = training_step(batch)
outs.append(out)
training_epoch_end(outs)
等价Lightning代码:
def training_step(self, batch, batch_idx):
prediction = ...
return prediction
def training_epoch_end(self, training_step_outputs):
for prediction in predictions:
# do something with these
我们需要做的,就是像填空一样,填这些函数。
3.用Trainer配置参数进行自动训练
- Pytorch Lightning的两大API之一,类似于“胶水”,将LightningModule各个部分连接形成完整的逻辑。
(1)方法
__init__(logger=True, checkpoint_callback=True, callbacks=None, \
default_root_dir=None, gradient_clip_val=0, process_position=0, num_nodes=1, \
num_processes=1, gpus=None, auto_select_gpus=False, tpu_cores=None, \
log_gpu_memory=None, progress_bar_refresh_rate=1, overfit_batches=0.0, \
track_grad_norm=-1, check_val_every_n_epoch=1, fast_dev_run=False, \
accumulate_grad_batches=1, max_epochs=1000, min_epochs=1, max_steps=None, \
min_steps=None, limit_train_batches=1.0, limit_val_batches=1.0, \
limit_test_batches=1.0, val_check_interval=1.0, flush_logs_every_n_steps=100, \
log_every_n_steps=50, accelerator=None, sync_batchnorm=False, precision=32, \
weights_summary='top', weights_save_path=None, num_sanity_val_steps=2, \
truncated_bptt_steps=None, resume_from_checkpoint=None, profiler=None, \
benchmark=False, deterministic=False, reload_dataloaders_every_epoch=False, \
auto_lr_find=False, replace_sampler_ddp=True, terminate_on_nan=False, \
auto_scale_batch_size=False, prepare_data_per_node=True, plugins=None, \
amp_backend='native', amp_level='O2', distributed_backend=None, \
automatic_optimization=True, move_metrics_to_cpu=False)
初始化训练器,参数很多,下面将分别介绍:
-
硬件参数:
- gpus[None]:
- 设置为0或None,表示使用cpu。
- 设置为大于0的整数n,表示使用n块gpu。
- 设置为大于0的整数字符串’n’,表示使用id为n的gpu。
- 设置为-1或’-1’,表示使用所有gpu。
- 设置为整数数组[a, b]或整数数组字符串’a, b’,表示使用id为a和b的gpu。
- auto_select_gpus[False]:
- 设置为True,自动选择所需gpu。
- 设置为False,按顺序选择所需gpu。
- num_nodes[1]:
- 设置为1,选择当前gpu节点。
- 设置为大于0的整数n,表示使用n个节点。
- tpu_cores[None]:
- 设置为None,表示不使用tpu。
- 设置为1,表示使用1个tpu内核。
- 设置为大于0的整数数组[n],表示使用id为n的tpu内核。
- 设置为8,表示使用所有tpu内核。
- gpus[None]:
-
精度参数:
- precision[32]:设置为2、4、8、16或32,分别表示不同的精度。
- amp_backend[“native”]:
- 设置为"native",表示使用本地混合精度。
- 设置为"apex",表示使用apex混合精度。
- amp_level[“O2”]: 设置为O0、O1、O2或O3,分别表示:
- O0:纯FP32训练,可以作为accuracy的baseline。
- O1:混合精度训练(推荐使用),根据黑白名单自动决定使用FP16(GEMM, 卷积)还是FP32(Softmax)进行计算。
- O2:“几乎FP16”混合精度训练,不存在黑白名单,除了Batch norm,几乎都是用FP16计算。
- O3:纯FP16训练,很不稳定,但是可以作为speed的baseline。
-
训练超参:
- max_epochs[1000]: 最大训练轮数。
- min_epochs[1]: 最小训练轮数。
- max_steps[None]: 每轮最大训练步数。
- min_steps[None]: 每轮最小训练步数。
-
日志参数和检查点参数:
- checkpoint_callback[True]:
- 设置为True,自动进行检查点保存。
- 设置为False,不进行检查点保存。
- logger[TensorBoardLogger]: 设置log工具。False表示不使用logger。
- default_root_dir[os.getcwd()]: 默认的根目录,用于日志和检查点的保存。
- flush_logs_every_n_steps[100]: 多少步更新一次日志到磁盘。
- log_every_n_steps[50]: 多少步更新一次日志到内存。
- log_gpu_memory[None]:
- 设置为None,不记录gpu显存信息。
- 设置为"all",记录所有gpu显存信息。
- 设置为"min_max",记录gpu显存信息最值。
- check_val_every_n_epoch[1]: 多少轮验证一次。
- val_check_interval[1.0]:
- 设置为小数,表示取一定比例的验证集。
- 设置为整数,表示取一定数量的验证集。
- resume_from_checkpoint[None]: 检查点恢复,输入路径。
- progress_bar_refresh_rate[1]: 进度条的刷新率。
- weights_summary[“top”]:
- 设置为None,不输出模型信息。
- 设置为"top",输出模型简要信息。
- 设置为"full",输出模型所有信息。
- weights_save_path[os.getcwd()]:
- 权重的保存路径。
- checkpoint_callback[True]:
-
测试参数:
- num_sanity_val_steps[2]: 训练前检查多少批验证数据。
- fast_dev_run[False]: 一系列单元测试。
- reload_dataloaders_every_epoch[False]: 每一轮是否重新载入数据。
-
分布式参数:
-
accelerator[None]:
-
dp(DataParallel)是在同一计算机的GPU之间拆分批处理。
-
ddp(DistributedDataParallel)是每个节点上的每个GPU训练并同步梯度。TPU默认选项。
-
ddp_cpu(CPU上的DistributedDataParallel)与ddp相同,但不使用GPU。对于多节点CPU训练或单节点调试很有用。
-
ddp2是节点上的dp,节点间的ddp。
-
-
accumulate_grad_batches[1]: 多少批进行一次梯度累积。
-
sync_batchnorm[False]: 同步批处理,一般是在分布式多GPU时使用。
-
-
自动参数:
- automatic_optimization[True]: 是否开启自动优化。
- auto_scale_batch_size[None]: 是否自动寻找最大批大小。
- auto_lr_find[False]: 是否自动寻找最佳学习率。
-
确定性参数:
- benchmark[False]: 是否使用cudnn.benchmark。
- deterministic[False]: 是否开启确定性。
-
限制性参数和采样参数:
- gradient_clip_val[0.0]: 梯度裁剪。
- limit_train_batches[1.0]: 限制每轮的训练批次数量。
- limit_val_batches[1.0]: 限制每轮的验证批次数量。
- limit_test_batches[1.0]: 限制每轮的测试批次数量。
- overfit_batches[0.0]: 限制批次的重复数量。
- prepare_data_per_node[True]: 是否对每个结点准备数据。
- replace_sampler_ddp[True]: 是否启用自动添加分布式采样器的功能。
-
其他参数:
- callbacks[]: 好家伙,callback。
- process_position[0]: 对进度条进行有序处理。
- profiler[None]
- track_grad_norm[-1]
- truncated_bptt_steps[None]
-
fit(model, train_dataloader=None, val_dataloaders=None, datamodule=None)
开启训练。参数如下:- datamodule (Optional[LightningDataModule]) – 一个LightningDataModule实例。
- model (LightningModule) – 训练的模型。
- train_dataloader (Optional[DataLoader]) – 训练数据。
- val_dataloaders (Union[DataLoader, List[DataLoader], None]) – 验证数据。
-
test(model=None, test_dataloaders=None, ckpt_path=‘best’, verbose=True, datamodule=None)
开启测试。参数如下:- ckpt_path (Optional[str]) – best或者你最希望测试的检查点权重的路径,None使用最后的权重。
- datamodule (Optional[LightningDataModule]) – 一个LightningDataModule实例。
- model (Optional[LightningModule]) – 测试的模型。
- test_dataloaders (Union[DataLoader, List[DataLoader], None]) – 测试数据。
- verbose (bool) – 是否打印结果。
-
tune(model, train_dataloader=None, val_dataloaders=None, datamodule=None)
训练之前调整超参数。参数如下:- datamodule (Optional[LightningDataModule]) – 一个LightningDataModule实例。
- model (LightningModule) – 调整的模型。
- train_dataloader (Optional[DataLoader]) – 训练数据。
- val_dataloaders (Union[DataLoader, List[DataLoader], None]) – 验证数据。
(2)属性
- callback_metrics回调指标。举个例子:
def training_step(self, batch, batch_idx):
self.log('a_val', 2)
callback_metrics = trainer.callback_metricpythons
assert callback_metrics['a_val'] == 2
- current_epoch 当前轮数。举个例子:
def training_step(self, batch, batch_idx):
current_epoch = self.trainer.current_epoch
if current_epoch > 100:
# do something
pass
- logger 当前日志。举个例子:
def training_step(self, batch, batch_idx):
logger = self.trainer.logger
tensorboard = logger.experiment
- logged_metrics 发送到日志的指标。举个例子:
def training_step(self, batch, batch_idx):
self.log('a_val', 2, log=True)
logged_metrics = trainer.logged_metrics
assert logged_metrics['a_val'] == 2
- log_dir 当前目录,用于保存图像等。举个例子:
def training_step(self, batch, batch_idx):
img = ...
save_img(img, self.trainer.log_dir)
-
is_global_zero 是否为全局第一个。
-
progress_bar_metrics 发送到进度条的指标。举个例子:
def training_step(self, batch, batch_idx):
self.log('a_val', 2, prog_bar=True)
progress_bar_metrics = trainer.progress_bar_metrics
assert progress_bar_metrics['a_val'] == 2
4.callback
- Pytorch Lightning最nb的插件,万能,无敌,随处可插,即插即用。
(1)训练方法
-
on_train_start(trainer, pl_module)
当第一次训练开始时的操作。 -
on_train_end(trainer, pl_module)
当最后一次训练结束时的操作。 -
on_train_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
当一批数据训练开始时的操作。 -
on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
当一批数据训练结束时的操作。 -
on_train_epoch_start(trainer, pl_module)
当一轮数据训练开始时的操作。 -
on_train_epoch_end(trainer, pl_module, outputs)
当一轮数据训练结束时的操作。
(2)验证方法
-
on_validation_start(trainer, pl_module)
当第一次验证开始时的操作。 -
on_validation_end(self, trainer, pl_module)
当最后一次验证结束时的操作。 -
on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
当一批数据验证开始时的操作。 -
on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
当一批数据验证结束时的操作。 -
on_validation_epoch_start(trainer, pl_module)
当一轮数据验证开始时的操作。 -
on_validation_epoch_end(trainer, pl_module)
当一轮数据验证结束时的操作。
(3)测试方法
-
on_test_start(trainer, pl_module)
当第一次测试开始时的操作。 -
on_test_end(self, trainer, pl_module)
当最后一次测试结束时的操作。 -
on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)
当一批数据测试开始时的操作。 -
on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
当一批数据测试结束时的操作。 -
on_test_epoch_start(trainer, pl_module)
当一轮数据测试开始时的操作。 -
on_test_epoch_end(trainer, pl_module)
当一轮数据测试结束时的操作。
(4)其他方法
-
on_fit_start(trainer, pl_module)
当调用.fit时的操作。 -
on_fit_end(trainer, pl_module)
.fit结束时的操作。
setup(trainer, pl_module, stage)
teardown(trainer, pl_module, stage)
on_init_start(trainer)
on_init_end(trainer)
on_sanity_check_start(trainer, pl_module)
on_sanity_check_end(trainer, pl_module)
on_batch_start(trainer, pl_module)
on_batch_end(trainer, pl_module)
on_epoch_start(trainer, pl_module)
on_epoch_end(trainer, pl_module)
on_keyboard_interrupt(trainer, pl_module)
on_save_checkpoint(trainer, pl_module)
on_load_checkpoint(checkpointed_state)