如何将pytorch-lightning训练的结果ckpt转为pytorch的pt?

# 寒冬已至,昼短苦夜长。

一 简单展示一下我是怎么保存lightning

from pytorch_lightning.callbacks import ModelCheckpoint

# # 定义 ModelCheckpoint 回调
checkpoint_callback = ModelCheckpoint(
    monitor='valid_f1',  # 监控的指标,可以是训练中的任何指标
    dirpath=f'logs/{suf}/',  # 指定保存模型参数的目录
    filename='model-{epoch:02d}-{valid_f1:.3f}',  # 模型参数文件名的格式
    save_top_k=3,  # 保存最佳的模型参数
    mode='max',
    save_last=True,
    save_weights_only=True, # 仅保存模型的权重参数
)
trainer = pl.Trainer(
    benchmark=True,
    accelerator="gpu",
    logger=logger,
    # devices=2,
    max_epochs=nb_epochs,
    # precision='16-mixed',
    accumulate_grad_batches=8,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback],#<------ 这里
)

我是怎么在绝望中前行的

a = torch.load('model-47.pth') #OrderedDict,pytorch
b = torch.load('model-24.ckpt') #OrderedDict,lightning

# 查看
for i in a:
    print(i)
    break
------
output:Convolution.0.weight
======
for j in b:
    print(j)
-------
epoch
global_step
pytorch-lightning_version
state_dict<------ 这里
loops
======
for j in b['state_dict']:
    print(j)
------
output:net.Convolution.0.weight

“net.”是由于我定义pl模型的时候弄的,改模型太累,改字典很轻松

# 创建一个新的模型状态字典
new_state_dict = {}

# 遍历b中的键
for b_key, b_value in c['state_dict'].items():
    # 删除键中的'net.'部分
    new_key = b_key.replace('net.', '')
    
    # 使用新的键来保存值
    new_state_dict[new_key] = b_value

# 将新的状态字典保存到一个新的文件
torch.save(new_state_dict, 'model_20.pt')

# 也可以直接导入模型
network.load_state_dict('new_state_dict')

注:苦寻无果后,自己倒腾出来的。写文章不易,请点赞关注谢谢。

如何将pytorch-lightning训练的结果保存为xxx.pt? - 知乎

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值