使用torch.onnx.export()转换模型为onnx
import torch
import torchvision
from utils import get_network #此处是自己定义的获取模型结构的模块
net = get_network("mobilenet") #以Mobilenet为例
weight = "***.pth" #只保存了模型参数
net.load_state_dict(torch.load(weight))
#若.pth文件保存了模型结构和参数,则直接 net = torch.load("***.pth")
net.eval()
dummy_input = torch.randn(1, 3, 128, 128) #这个得根据自己网络输入的shape来
input_names = ["input0"] #命名随意取
output_names = ["output1"] #命名随意取
onnx_save = "mobilenet.onnx" #转换完模型保存路径
torch.onnx.export(net, dummy_input, onnx_save, verbose=True, input_names=input_names, output_names=output_names)