# 寒冬已至,昼短苦夜长。
一 简单展示一下我是怎么保存lightning
from pytorch_lightning.callbacks import ModelCheckpoint
# # 定义 ModelCheckpoint 回调
checkpoint_callback = ModelCheckpoint(
monitor='valid_f1', # 监控的指标,可以是训练中的任何指标
dirpath=f'logs/{suf}/', # 指定保存模型参数的目录
filename='model-{epoch:02d}-{valid_f1:.3f}', # 模型参数文件名的格式
save_top_k=3, # 保存最佳的模型参数
mode='max',
save_last=True,
save_weights_only=True, # 仅保存模型的权重参数
)
trainer = pl.Trainer(
benchmark=True,
accelerator="gpu",
logger=logger,
# devices=2,
max_epochs=nb_epochs,
# precision='16-mixed',
accumulate_grad_batches=8,
enable_checkpointing=True,
callbacks=[checkpoint_callback],#<------ 这里
)
我是怎么在绝望中前行的
a = torch.load('model-47.pth') #OrderedDict,pytorch
b = torch.load('model-24.ckpt') #OrderedDict,lightning
# 查看
for i in a:
print(i)
break
------
output:Convolution.0.weight
======
for j in b:
print(j)
-------
epoch
global_step
pytorch-lightning_version
state_dict<------ 这里
loops
======
for j in b['state_dict']:
print(j)
------
output:net.Convolution.0.weight
“net.”是由于我定义pl模型的时候弄的,改模型太累,改字典很轻松
# 创建一个新的模型状态字典
new_state_dict = {}
# 遍历b中的键
for b_key, b_value in c['state_dict'].items():
# 删除键中的'net.'部分
new_key = b_key.replace('net.', '')
# 使用新的键来保存值
new_state_dict[new_key] = b_value
# 将新的状态字典保存到一个新的文件
torch.save(new_state_dict, 'model_20.pt')
# 也可以直接导入模型
network.load_state_dict('new_state_dict')
注:苦寻无果后,自己倒腾出来的。写文章不易,请点赞关注谢谢。