Pytorch Lightning 1.5.1 到 Pytorch Lightning 2.0.0 的迁移坑
问题描述
- 版本兼容报错,可参考函数替代与参数修改节,以及直接根据终端输出修改。
- 由 all_step_outputs 迁移修改至分步的 training_step_outputs, validation_step_outputs 和 test_step_outputs。结合官方 github 中迁移说明和 Pytoch Lighning 官方文档修改(😇)。
函数替代与参数修改(部分)
Pytorch Lightning 1.5.1 LightningModule
中的 training_epoch_end( ) 在Pytorch Lightning 2.0.0 LightningModule
中被替换为 on_train_epoch_end( )。类似的有 validation_epoch_end ( ) 及 test_epoch_end( )。原文参考:github 迁移。pytorch_lightning.trainer.trainer
参数发生以下变化:
① accelerator (Union[str, Accelerator, None]) --> accelerator (Union[str, Accelerator]),不再支持将 strategy 参数在 accelerator 中导入。
② checkpoint_callback (Optional[bool]) 参数被移除,该参数功能被 enable_checkpointing (Optional[bool]) 合并替代。
③ flush_logs_every_n_steps (Optional[int]) 和 log_gpu_memory (Optional[str]) 被移除,需用户在相应模块中自定义。
上述参考来源:Pytorch Lightning 1.5.1 官方文档 Trainer
All_step_outputs 的迁移
在官方 github 的说明中,给出了函数替代的修改,以及在 _init_
模块里新增 train_step_outputs 等变量定义,其修改如下:
class MyLightningModule(L.LightningModule):
+ def __init__(self):
+ super().__init__()
+ self.training_step_outputs = []
def training_step(self, ...):
loss = ...
+ self.training_step_outputs.append(loss)
return loss
#-------------------- 删除部分 ----------------------
#- def training_epoch_end(self, outputs):
#- epoch_average = torch.stack([output["loss"] for output in outputs]).mean()
#-------------------- 删除部分 ----------------------
+ def on_train_epoch_end(self):
+ epoch_average = torch.stack(self.training_step_outputs).mean()
self.log("training_epoch_average", epoch_average)
+ self.training_step_outputs.clear() # free memory
在Pytorch Lightning 2.0.0
官方文档 中 给出的 on_train_epoch_end( ) 简单模板:
class MyLightningModule(L.LightningModule):
def __init__(self):
super().__init__()
self.training_step_outputs = []
def training_step(self):
loss = ...
self.training_step_outputs.append(loss)
return loss
class MyCallback(L.Callback):
def on_train_epoch_end(self, trainer, pl_module):
# do something with all training_step outputs, for example:
epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
pl_module.log("training_epoch_mean", epoch_mean)
# free up the memory
pl_module.training_step_outputs.clear()
在 training_step 的末尾一定要加上 self.xxx.append( ) 操作,在 1.5.1
的版本中,这个操作似乎隐式的完成了,而在新版中如果不 append 在进行 evaluate 等操作时常常会得到空的输出。
validation_step_outputs() 与 test_step_outputs() 类似。
【踩坑ing】