Conversion code. Note: adapt this to your own project — fix the initial imports, checkpoint paths, etc., before running.
import torch
from models.with_mobilenet import PoseEstimationWithMobileNet
from modules.load_state import load_state
from action_detect.net import NetV2
def convert_onnx(checkpoint_path=r'E:/xg_openpose_fall_detect-master/weights/checkpoint_iter_370000.pth',
                 onnx_path='E:/xg_openpose_fall_detect-master/weights/openpose.onnx',
                 input_size=(256, 456)):
    """Export the pretrained OpenPose (MobileNet backbone) model to ONNX.

    Args:
        checkpoint_path: Path of the .pth checkpoint to convert.
        onnx_path: Destination path for the exported ONNX model.
        input_size: (height, width) of the dummy tracing input.
    """
    print('start!!!')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = PoseEstimationWithMobileNet().to(device)
    # Load weights on CPU; load_state copies the matching tensors into the model.
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    load_state(model, checkpoint)
    model.eval()  # inference mode for a deterministic traced graph
    # Dummy input in NCHW layout used to trace the network.
    dummy_input = torch.randn(1, 3, input_size[0], input_size[1]).to(device)
    print("----- pth导出为onnx模型 -----")
    torch.onnx.export(model, dummy_input, onnx_path, export_params=True,
                      input_names=['input'], output_names=['output'])
    print('convert openpose to onnx finish!!!')


if __name__ == "__main__":
    convert_onnx()
Official implementation: https://github.com/openvinotoolkit/training_extensions/tree/develop/misc/pytorch_toolkit/human_pose_estimation
import argparse
import torch
from models.with_mobilenet import PoseEstimationWithMobileNet
#from models.single_person_pose_with_mobilenet import SinglePersonPoseEstimationWithMobileNet
from modules.load_state import load_state
def convert_to_onnx(net, output_name, single_person, input_size):
input = torch.randn(1, 3, input_size[0], input_size[1])
input_layer_names = ['data']
output_layer_names = ['stage_0_output_1_heatmaps', 'stage_0_output_0_pafs',
'stage_1_output_1_heatmaps', 'stage_1_output_0_pafs']
if single_person:
input = torch.randn(1, 3, input_size[0], input_size[1])
output_layer_names = ['stage_{}_output_1_heatmaps'.format(i) for i in range(len(net.refinement_stages) + 1)]
torch.onnx.export(net, input, output_name, verbose=True, input_names=input_layer_names,
output_names=output_layer_names)
if __name__ == '__main__':
    # Command-line entry point of the official conversion script.
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint-path', type=str, required=True, help='path to the checkpoint')
    parser.add_argument('--output-name', type=str, default='human-pose-estimation.onnx',
                        help='name of output model in ONNX format')
    parser.add_argument('--single-person', action='store_true', help='convert model for single-person pose estimation')
    parser.add_argument('--input-size', nargs='+', type=int, required=True,
                        help='Size of input image in format: height width')
    parser.add_argument('--mode-interpolation', type=str, required=False, default='bilinear',
                        help='type interpolation <bilinear> or <nearest>')
    parser.add_argument('--num-refinement-stages', type=int, default=1, help='number of refinement stages')
    args = parser.parse_args()
    net = PoseEstimationWithMobileNet()
    # Single-person variant disabled: its model class import is commented out above.
    '''
    if args.single_person:
        net = SinglePersonPoseEstimationWithMobileNet(mode=args.mode_interpolation, num_refinement_stages=args.num_refinement_stages)
    '''
    # No map_location here — loads onto the device(s) the checkpoint was saved from.
    checkpoint = torch.load(args.checkpoint_path)
    # Example args: --checkpoint-path checkpoint_iter_370000.pth --input-size 256 456
    load_state(net, checkpoint)
    convert_to_onnx(net, args.output_name, args.single_person, args.input_size)
Usage: python scripts/convert_to_onnx.py --checkpoint-path <CHECKPOINT> --input-size <HEIGHT> <WIDTH>   (both arguments are required)
e.g.: python scripts/convert_to_onnx.py --checkpoint-path checkpoint_iter_370000.pth --input-size 256 456