pytorch模型转onnx

pt转onnx

Support Matrix
torch.onnx
onnx上查看转换方法

开始转换

  1. 加载pt文件

torchvision支持的model 地址

model = torchvision.models.xxx()
model .load_state_dict(torch.load("xxx.pth"))
model.eval()

不支持的,以resnet50_nfc为例

model = get_model("resnet50", num_label, use_id=False, num_id=num_id)  ##自定义的网络
model .load_state_dict(torch.load("xxx.pth"))
model.eval()
  1. onnx导出
##准备导出
dummy_input = torch.randn(3, 3, 224, 224, device='cpu')
##输入和输出的名字
input_names = ["actual_input_1"] + ["learned_%d" % i for i in range(1)]
output_names = ["output1"]
##导出为rr.onnx
torch.onnx.export(model, dummy_input, "rr.onnx", verbose=True, input_names=input_names, output_names=output_names)
  1. 验证

torch.rand 和 numpy.rand 互转

 x = torch.rand(2,2)
 x1 = x.numpy() # torch转换到numpy
 x2 = torch.from_numpy(x1) #numpy转换torch

开始验证

import onnxruntime as ort
import numpy as np
##导入onnx文件
ort_session = ort.InferenceSession('rr.onnx')
# outputs = ort_session.run(None, {'actual_input_1': np.random.randn(3, 3, 224, 224).astype(np.float32)})
##dummy_input为torch的随机矩阵,转为numpy的
outputs = ort_session.run(None, {'actual_input_1': dummy_input.numpy()})
##打印
print(model.forward(dummy_input))
print(outputs[0])

运行结果
结果一样

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值