"""Export a torchvision Faster R-CNN (ResNet-50 FPN) state dict to ONNX."""
import os

import torch
from torchvision import models


def mytranslate_pth_onnx(save_weight_pth, num=1, channel=3, height=512, width=672,
                         *, onnx_path=None):
    """Load a Faster R-CNN ResNet-50 FPN state dict and export the model to ONNX.

    Originally written against Torch Version 1.8.2+cu101.

    :param save_weight_pth: path (str or path-like) of the ``.pth`` state-dict
        file to load into the model.
    :param num: batch size N of the random dummy input used for tracing (default 1).
    :param channel: channel count C of the dummy input (default 3).
    :param height: height H of the dummy input (default 512).
    :param width: width W of the dummy input (default 672).
    :param onnx_path: destination of the ONNX file; when omitted it is
        ``save_weight_pth`` with ``".onnx"`` appended (e.g. ``x.pth.onnx``),
        which matches the previous behavior.
    :return: the path the ONNX file was written to.
    """
    print('Torch Version', torch.__version__)
    print('格式转换中...')

    # Random dummy input in N, C, H, W layout; torch.onnx.export traces the
    # model by running it on this tensor.
    example = torch.rand(num, channel, height, width)

    # Build the network architecture.
    # NOTE(review): pretrained=True makes torchvision fetch (or reuse its
    # cached) COCO weights, which load_state_dict then overwrites. It is kept
    # so that model construction is unchanged across torchvision versions.
    model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

    # Map the checkpoint onto the CPU. The model and the dummy input are on the
    # CPU, so the previous `storage.cuda(0)` mapping only added a hard CUDA
    # requirement (crashing on CPU-only machines) without any benefit.
    checkpoint = torch.load(save_weight_pth, map_location='cpu')
    model.load_state_dict(checkpoint)

    # Switch to inference mode, then export to ONNX.
    model.eval()
    if onnx_path is None:
        # os.fspath lets pathlib.Path inputs work; plain str is unchanged.
        onnx_path = os.fspath(save_weight_pth) + ".onnx"
    torch.onnx.export(model, example, onnx_path, verbose=True, opset_version=11)
    print("格式转换完成!")
    return onnx_path


if __name__ == "__main__":
    # Weight (state-dict) file to convert.
    save_weight_pth = r'C:\Users\admin\.cache\torch\hub\checkpoints\fasterrcnn_resnet50_fpn_coco-258fb6c6.pth'
    num = 1
    channel = 3
    height = 512
    width = 672
    # Pass the settings above through. Previously they were shadowed by
    # duplicated literals, so editing them had no effect.
    mytranslate_pth_onnx(save_weight_pth, num=num, channel=channel,
                         height=height, width=width)
PyTorch Faster R-CNN 权重文件导出为 ONNX 格式
最新推荐文章于 2023-07-05 11:17:07 发布