将 PyTorch 模型转换为 ONNX 模型的代码如下:
import torch
import torch.nn as nn
import torch.onnx
import onnx
import os
from QualityNet import QualityNet
if __name__ == '__main__':
    # Converts the face-quality PyTorch checkpoint to ONNX and validates the result.
    # NOTE: the checkpoint's keys are passed to load_state_dict() unchanged. A
    # checkpoint saved from an nn.DataParallel model carries a "module." prefix
    # on every key, which does not match the bare QualityNet built below.

    # Restrict CUDA to the device with index 1 under PCI bus ordering.
    os.environ.update({
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "1",
    })
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu')

    checkpoint = torch.load('./models/pytorch/face_quality.pth')

    # This DataParallel-wrapped instance is built and then discarded:
    # `net` is rebound to a plain QualityNet on the next line.
    net = nn.DataParallel(QualityNet())
    net = QualityNet().to(device)
    net.load_state_dict(checkpoint)
    net.eval()

    onnx_path = './models/onnx/face_quality.onnx'
    # All-ones tensor of shape (1, 3, 128, 128) used to trace the model.
    sample = torch.ones(1, 3, 128, 128).to(device)

    # export onnx model
    torch.onnx.export(
        net,
        sample,
        onnx_path,
        verbose=False,
        opset_version=9,
        input_names=["input"],
        output_names=["output"],
    )

    # load the exported file back and run the ONNX checker on it
    onnx.checker.check_model(onnx.load(onnx_path))
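运行后,`model.load_state_dict(state_dict)` 会报错(通常是 RuntimeError,提示 state_dict 中存在 Missing key(s) 以及以 "module." 开头的 Unexpected key(s)):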
PyTorch 采用 DataParallel 进行多卡训练得到的模型文件,无法直接加载并转换为 ONNX 模型。原因是 DataParallel 会把原模型包装在名为 module 的属性中,因此保存下来的模型文件里,每个键值对的 key 前面都会多出一个 "module." 前缀:
解决方法很简单,只需要去掉多余的 "module." 前缀即可:重新创建一个 OrderedDict,把去掉前缀后的键和原来的值逐个写入,然后将它载入模型。修改后的代码如下:
import torch
import torch.nn as nn
from collections import OrderedDict
import torch.onnx
import onnx
import os
from QualityNet import QualityNet
def strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of *state_dict* with *prefix* removed from its keys.

    nn.DataParallel stores the wrapped network under the attribute ``module``,
    so a state dict saved from the wrapper has every key prefixed with
    ``"module."``. Keys that do not start with *prefix* are copied unchanged,
    which makes this safe to call on checkpoints saved without DataParallel.

    Args:
        state_dict: Mapping of parameter name to value.
        prefix: Leading string to drop from each key that has it.

    Returns:
        A new OrderedDict with the same values in the same order.
    """
    cleaned = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


if __name__ == '__main__':
    # Restrict CUDA to the device with index 1 under PCI bus ordering.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = "1"
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    model_path = './models/pytorch/face_quality.pth'
    # map_location lets a checkpoint saved on a GPU load on whichever device
    # was selected above, including the CPU fallback.
    state_dict = torch.load(model_path, map_location=device)
    new_state_dict = strip_module_prefix(state_dict)

    model = QualityNet().to(device)
    model.load_state_dict(new_state_dict)
    model.eval()

    onnx_path = './models/onnx/face_quality.onnx'
    # Create the output directory up front so the export cannot fail on a
    # missing folder.
    os.makedirs(os.path.dirname(onnx_path), exist_ok=True)

    # All-ones tensor of shape (1, 3, 128, 128) used to trace the model.
    dummy_input = torch.ones(1, 3, 128, 128).to(device)
    input_names = ["input"]
    output_names = ["output"]

    # export onnx model
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        verbose=False,
        opset_version=9,
        input_names=input_names,
        output_names=output_names,
    )

    # load onnx model
    onnx_model = onnx.load(onnx_path)
    # check onnx model
    onnx.checker.check_model(onnx_model)