pytorch框架自带onnx模型转换代码,如下使用resnet34和mobilenetV2为例
resnet34:
首先下载resnet34的pytorch模型,各版本resnet模型下载地址如下
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
下载完后使用如下代码进行模型转换:
import torch
def main():
model = torch.load("resnet34.pth")
imput = torch.randn(1,3,224,224)
torch.onnx.export(model, input, "resnet34.onnx", verbose=True, opset_version=11, export_params=True)
if __name__ == '__main__':
main()
mobilenetV2:
各类轻量化模型的下载地址如下:
下载完后使用如下代码进行模型转换:
import torch
from light_cnns import mbv1
def main():
model = mbv1()
imput = torch.randn(1,3,224,224)
torch.onnx.export(model, input, "mobilenetV2.onnx", verbose=True, opset_version=11, export_params=True)
if __name__ == '__main__':
main()