pytorch训练的.pth模型格式转换

深度网络模型再工程化调用时存在模型不识别的问题,主要时推理的程序有不同的语言编写的,本文介绍使用pytorch训练的模型格式.pth转化为.pt格式的代码。

pytorch网络代码地址:https://github.com/yassouali/pytorch_segmentation ,安装请参考《pytorch框架下语义分割训练实践(一)》。

1 .pth模型转换为.pt模型

我是用fcn训练得到一个语言分割模型,checkpoint-epoch100.pth,使用inference.py文件可以正常调用,但用c++去调用的却不能正常载入。

使用下面的python脚本将.pth模型转换为.pt格式。

import torch
import torchvision
from models import fcn
 
model=torchvision.models.vgg16()
state_dict = torch.load("./checkpoint-epoch100.pth")
#print(state_dict)
model.load_state_dict(state_dict,False)
model.eval()
 
x = torch.rand(1,3,128,128)
ts = torch.jit.trace(model, x)
ts.save('fcn_vgg16.net')

注意很多人在转换的时候报错是因为:model.load_state_dict(state_dict)后面没用False参数,如下图所示。

2. .pth模型转化为.onnx模型

如需使用opencv来加载模型,则需将.pth转化为.onnx格式的模型。a.先安装onnx,使用命令:pip install onnx;b.使用以下命令转为.onnx模型

import io
import torch
import torch.onnx
import torchvision
from models import fcn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def test():
    model=torchvision.models.vgg16()
 
    pthfile = r'./checkpoint-epoch100.pth'
    loaded_model = torch.load(pthfile, map_location='cpu')
    # try:
    #     loaded_model.eval()
    # except AttributeError as error:
    #     print(error)

    #model.load_state_dict(loaded_model['state_dict'])
    # model = model.to(device)

    #data type nchw
    dummy_input1 = torch.randn(1, 3, 244, 244)
    # dummy_input2 = torch.randn(1, 3, 64, 64)
    # dummy_input3 = torch.randn(1, 3, 64, 64)
    input_names = [ "actual_input_1"]
    output_names = [ "output1" ]
    # torch.onnx.export(model, (dummy_input1, dummy_input2, dummy_input3), "C3AE.onnx", verbose=True, input_names=input_names, output_names=output_names)
    torch.onnx.export(model, dummy_input1, "fcn.onnx", verbose=True, input_names=input_names, output_names=output_names)

if __name__ == "__main__":
	test()

 

  • 3
    点赞
  • 42
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 9
    评论
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

三月微暖寻春笋

赠人玫瑰手有余香

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值