将 PyTorch 模型转换为 ONNX 模型。
使用 DDP(DistributedDataParallel)训练的 PyTorch 模型,其权重名称前会带有 `module.` 前缀,
加载权重之前需要先把这个前缀去掉。
# Strip the leading "module." that DDP training puts in front of every
# weight name, so the weights can be loaded into the unwrapped model.
ddp_prefix = 'module.'
new_checkpoint = {}
for k, value in checkpoint.items():
    # Remove the prefix only when it is at the START of the key. Splitting on
    # every occurrence would also truncate keys that merely contain
    # "module." further in (e.g. "module.backbone.submodule.weight").
    key = k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k
    new_checkpoint[key] = value
    print(k, key)
# strict=True makes load_state_dict fail loudly on missing/unexpected keys.
model_pos.load_state_dict(new_checkpoint, strict=True)
完整代码如下,
其中包含对 PyTorch 模型与转换后的 ONNX 模型输出是否一致的验证。
import os

import numpy as np
import onnx
import onnxruntime as rt
import torch

from utils import *
# Locate the checkpoint file: args.resume is used when it is set, otherwise
# args.evaluate.
chk_filename = os.path.join(args.checkpoint, args.resume or args.evaluate)
print('Loading checkpoint', chk_filename)
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source.
# The map_location callback returns the storage unchanged, i.e. tensors are
# kept on the CPU instead of being moved back to the device they were saved
# from.
checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)
checkpoint = checkpoint['model_pos']

# Strip the leading "module." that DDP training puts in front of every
# weight name, so the weights can be loaded into the unwrapped model.
ddp_prefix = 'module.'
new_checkpoint = {}
for k, value in checkpoint.items():
    # Remove the prefix only when it is at the START of the key. Splitting on
    # every occurrence would also truncate keys that merely contain
    # "module." further in (e.g. "module.backbone.submodule.weight").
    key = k[len(ddp_prefix):] if k.startswith(ddp_prefix) else k
    new_checkpoint[key] = value
    print(k, key)
# strict=True makes load_state_dict fail loudly on missing/unexpected keys.
model_pos.load_state_dict(new_checkpoint, strict=True)
# Export model_pos to ONNX, then check that ONNX Runtime reproduces the
# PyTorch outputs for the same input.
output_file = './pf_f_27_1s.onnx'
# Random input used both for tracing the export and for the comparison.
test_instance = torch.rand((1, 27, 17, 2))

# Run everything on the CPU in inference mode. Without eval(), layers such as
# dropout / batch norm would behave differently in the PyTorch reference pass
# below than in the exported graph, producing a spurious mismatch.
model_pos = model_pos.cpu()
model_pos.eval()

print('export .........')
torch.onnx.export(
    model_pos,
    test_instance,
    output_file,
    input_names=['input'],
    output_names=["output"],
    opset_version=10,
)
print('Finished ******************')

# Structural check of the exported file by onnx itself.
onnx_model = onnx.load(output_file)
onnx.checker.check_model(onnx_model)

# Numerical check: reference output from PyTorch. no_grad avoids building an
# autograd graph for a pure inference pass.
with torch.no_grad():
    pytorch_results = model_pos(test_instance)
if not isinstance(pytorch_results, (list, tuple)):
    assert isinstance(pytorch_results, torch.Tensor)
    # Normalise a single tensor to a list so it can be zipped with the
    # list of outputs returned by ONNX Runtime.
    pytorch_results = [pytorch_results]

# graph.input may list initializers (weights) as well, so subtract their
# names to find the one tensor that actually has to be fed.
input_all = [node.name for node in onnx_model.graph.input]
input_initializer = [node.name for node in onnx_model.graph.initializer]
net_feed_input = list(set(input_all) - set(input_initializer))
assert len(net_feed_input) == 1

# Output from ONNX Runtime for the same input.
sess = rt.InferenceSession(output_file)
onnx_results = sess.run(
    None, {net_feed_input[0]: test_instance.detach().numpy()}
)

# Compare the two sets of outputs element-wise.
assert len(pytorch_results) == len(onnx_results)
for pt_result, onnx_result in zip(pytorch_results, onnx_results):
    assert np.allclose(
        pt_result.detach().cpu().numpy(), onnx_result, atol=1.e-5
    ), 'The outputs are different between Pytorch and ONNX'
print('The numerical values are same between Pytorch and ONNX')