将pytorch的pth文件固化为pt文件

说明

我参考了一个开源的人像语义分割项目mobile_phone_human_matting,这个项目提供了预训练模型,我想要将该模型固化,然后转换格式后在嵌入式端使用。

该项目保存模型的代码如下:

lastest_out_path = "{}/ckpt_lastest.pth".format(self.save_dir_model)
        torch.save({
            'epoch': epoch,
            'state_dict': model.state_dict(),
        }, lastest_out_path)

转换代码

上面代码保存了state_dict, 所以保存的文件中是不含模型结构的,固化时需要从代码构造网络结构。好在项目是完全开源,将原项目下的model目录拷贝过来就行。
另外不能忘记调用eval() 来固化参数。

完整的转换代码如下:

import torch
from model import segnet

ckptfile="./ckpt_lastest.pth"
savedfile="./human_seg.pt"

model = segnet.SegMattingNet()    

device = torch.device('cpu')
ckpt = torch.load(ckptfile, map_location=device )
model.load_state_dict(ckpt['state_dict'])

model.eval() #这一步会将参数固化,不能省。否则会报AssertionError('batchnorm with training is not support. Please set model.eval() before export.')

x = torch.rand(1,3,256,256)
ts = torch.jit.trace(model, x)
ts.save(savedfile)

参考资料

mobile_phone_human_matting

pytorch训练的.pth模型格式转换

  • 2
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值