问题起源
由于pl保存的是ckpt文件,当我想加载一个ckpt,用模型做推理,但我同时需要一些ckpt里的信息,例如当前的epoch。通过Trainer加载模型后,有个Trainer.current_epoch,欣喜若狂,然而打印出来却是0;或者nn_model.load_from_checkpoint,也有这个nn_model.current_epoch,结果都是0。
顺着Trainer的代码翻到了一句说明:current_epoch会在继续训练的循环开始后才更新。可我就想在你循环开始之前就用啊。
现在国内的pl讨论太少了,一搜全是搬运b乎的几个大佬的教程;至于一些其他的小问题,几乎没有讨论。
解决办法
功力深厚的巨佬可能根本不会遇到这个问题,因为他们懂torch怎么加载ckpt,查看ckpt里的参数。可我既然用了pl,我就必须用pl的方式打开。
去官网的doc搜,直接上链接。
确实,pl也是用的torch。
上代码:
import pytorch_lightning as pl
from pytorch_lightning.plugins.io import TorchCheckpointIO as tcio
# 实例化自己的model
nn_model = A()
ckpt_path = 'abc.ckpt'`在这里插入代码片`
trainer = pl.Trainer(resume_from_checkpoint=ckpt_path)
# 实例化函数
tc = tcio()
ckpt_dict = tc.load_checkpoint(path=ckpt_path)
# 返回的是字典,内容相当丰富。