DETR导出onnx模型,并进行推理(cpu环境)

        首先在detr项目目录下创建onnx文件夹,用于存放detr的pth权重文件(导出脚本默认读取onnx/detr-r50-e632da11.pth,可通过--model_dir参数修改),后续导出的onnx文件也存放在此。

        在detr项目目录下创建export_onnx.py文件,将下面代码拷贝之后直接运行即可导出detr.onnx模型,onnx模型存放到onnx文件夹下。

import io
import argparse
import onnx
import onnxruntime
import torch
from hubconf import detr_resnet50


class ONNXExporter:
    """Export a PyTorch model to ONNX and validate the result with ONNX Runtime."""

    @classmethod
    def setUpClass(cls):
        """Seed PyTorch's global RNG with 123 so repeated runs are reproducible."""
        torch.manual_seed(123)

    @staticmethod
    def _to_numpy(tensor):
        """Return ``tensor`` as a NumPy array on the CPU, detached from autograd."""
        # detach() is a no-op for tensors that do not require grad, so one
        # code path covers both cases.
        return tensor.detach().cpu().numpy()

    def run_model(self, model, onnx_path, inputs_list, tolerate_small_mismatch=False,
                  do_constant_folding=True,
                  output_names=None, input_names=None):
        """Export ``model`` to ONNX, save it, and validate it on every input.

        The model is traced once with ``inputs_list[0]`` (opset 12). The
        serialized bytes are written to ``onnx_path`` and the very same bytes
        are then loaded by ONNX Runtime and compared against the PyTorch
        outputs for each entry of ``inputs_list``.

        Args:
            model: The ``torch.nn.Module`` to export; it is switched to eval mode.
            onnx_path: Destination of the ONNX file (a path, or a writable
                binary file-like object).
            inputs_list: Sequence of example inputs. The first one is used for
                tracing, all of them for validation.
            tolerate_small_mismatch: When true, numerical mismatches are
                printed instead of raised.
            do_constant_folding: Forwarded to ``torch.onnx.export``.
            output_names: Names given to the ONNX graph outputs.
            input_names: Names given to the ONNX graph inputs.

        Raises:
            AssertionError: If outputs differ beyond tolerance and
                ``tolerate_small_mismatch`` is false.
        """
        model.eval()

        # Export exactly once. The model is already in eval mode and the
        # exporter defaults to inference mode, so no ``training`` argument is
        # passed (a plain ``False`` is not a valid value in newer PyTorch).
        onnx_io = io.BytesIO()
        torch.onnx.export(model, inputs_list[0], onnx_io,
                          input_names=input_names, output_names=output_names,
                          export_params=True, opset_version=12,
                          do_constant_folding=do_constant_folding)

        # Persist the bytes that will be validated below, instead of tracing
        # the model a second time.
        model_bytes = onnx_io.getvalue()
        if hasattr(onnx_path, "write"):
            onnx_path.write(model_bytes)
        else:
            with open(onnx_path, "wb") as onnx_file:
                onnx_file.write(model_bytes)

        print(f"[INFO] ONNX model export success! save path: {onnx_path}")

        # validate the exported model with onnx runtime
        for test_inputs in inputs_list:
            with torch.no_grad():
                # A bare tensor or list is a single positional argument.
                if isinstance(test_inputs, (torch.Tensor, list)):
                    test_inputs = (test_inputs,)
                test_outputs = model(*test_inputs)
                if isinstance(test_outputs, torch.Tensor):
                    test_outputs = (test_outputs,)
            self.ort_validate(onnx_io, test_inputs, test_outputs, tolerate_small_mismatch)

    def ort_validate(self, onnx_io, inputs, outputs, tolerate_small_mismatch=False):
        """Run the in-memory ONNX model and compare it with the PyTorch outputs.

        Args:
            onnx_io: ``io.BytesIO`` holding the serialized ONNX model.
            inputs: (Possibly nested) tensors that were fed to the PyTorch model.
            outputs: (Possibly nested) tensors the PyTorch model returned.
            tolerate_small_mismatch: When true, a failed comparison is printed
                rather than raised.

        Raises:
            AssertionError: If an output differs beyond ``rtol=1e-3`` /
                ``atol=1e-5`` and ``tolerate_small_mismatch`` is false.
        """
        # Imported locally because this module does not import numpy at the
        # top level; both arrays being compared are already NumPy arrays.
        import numpy as np

        inputs, _ = torch.jit._flatten(inputs)
        outputs, _ = torch.jit._flatten(outputs)

        inputs = [self._to_numpy(tensor) for tensor in inputs]
        outputs = [self._to_numpy(tensor) for tensor in outputs]

        ort_session = onnxruntime.InferenceSession(onnx_io.getvalue())
        # Flattened inputs are matched to the graph inputs by position.
        session_inputs = ort_session.get_inputs()
        ort_inputs = {session_inputs[i].name: value for i, value in enumerate(inputs)}
        ort_outs = ort_session.run(None, ort_inputs)

        for i, torch_out in enumerate(outputs):
            try:
                np.testing.assert_allclose(torch_out, ort_outs[i], rtol=1e-03, atol=1e-05)
            except AssertionError as error:
                if tolerate_small_mismatch:
                    print(error)
                else:
                    raise

    @staticmethod
    def check_onnx(onnx_path):
        """Load the ONNX file at ``onnx_path`` and run the ONNX model checker on it."""
        model = onnx.load(onnx_path)
        onnx.checker.check_model(model)
        print(f"[INFO]  ONNX model: {onnx_path} check success!")



if __name__ == '__main__':
    # Command-line entry point: load a DETR checkpoint, export it to ONNX,
    # validate it with ONNX Runtime and optionally run the ONNX checker.

    parser = argparse.ArgumentParser(description='DETR Model to ONNX Model')
    # Path of the DETR .pth checkpoint to convert.
    parser.add_argument('--model_dir', type=str, default='onnx/detr-r50-e632da11.pth',
                        help='DETR Pytorch Model Saved Dir')
    # The check is enabled by default, so ``--check`` alone could never turn
    # it off; ``--no_check`` writes to the same destination to disable it.
    parser.add_argument('--check', default=True, action="store_true", help='Check Your ONNX Model')
    parser.add_argument('--no_check', dest='check', action='store_false',
                        help='Skip the ONNX model check')
    # Path where the converted ONNX file is written.
    parser.add_argument('--onnx_dir', type=str, default="onnx/detr.onnx", help="Check ONNX Model's dir")
    parser.add_argument('--batch_size', type=int, default=1, help="Batch Size")

    args = parser.parse_args()

    # load torch model
    detr = detr_resnet50(pretrained=False, num_classes=90 + 1).eval()  # max label index add 1
    # Weights are mapped to the CPU so the export also works without a GPU.
    state_dict = torch.load(args.model_dir, map_location='cpu')  # model path
    detr.load_state_dict(state_dict["model"])

    # dummy input: a batch of all-ones 3x800x800 images used for tracing
    dummy_image = [torch.ones(args.batch_size, 3, 800, 800)]

    # to onnx
    onnx_export = ONNXExporter()
    onnx_export.run_model(detr, args.onnx_dir, dummy_image, input_names=['inputs'],
                          output_names=["pred_logits", "pred_boxes"], tolerate_small_mismatch=True)

    # check onnx model
    if args.check:
        ONNXExporter.check_onnx(args.onnx_dir)

        导出的时候可能会提示警告:

        无视就好,稍等一两分钟就可以完成onnx的导出。 

        导出后,在同级目录下创建inference_onnx.py文件,使用刚才导出的onnx模型进行预测。待检测的图片放在onnx/images文件夹下,检测结果会保存到onnx/result文件夹中。

import cv2
from PIL import Image
import numpy as np
import os
import random

try:
    import onnxruntime
except ImportError:
    onnxruntime = None

import torch
import torchvision.transforms as T

# Inference only: switch autograd off for the whole script.
torch.set_grad_enabled(False)

# Use the first CUDA device when one is available, otherwise the CPU.
_device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# Per-channel statistics used to normalize the input tensor.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline: resize to 800x800, convert to a tensor, normalize.
transform = T.Compose([
    T.Resize((800, 800)),
    T.ToTensor(),
    T.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])


def box_cxcywh_to_xyxy(x):
    """Convert boxes from (center_x, center_y, width, height) to corner form.

    Args:
        x: NumPy array whose second axis holds the four box values.

    Returns:
        A torch tensor with columns (x_min, y_min, x_max, y_max).
    """
    boxes = torch.from_numpy(x)
    center_x, center_y, width, height = boxes.unbind(1)
    half_w = 0.5 * width
    half_h = 0.5 * height
    corners = [
        center_x - half_w,
        center_y - half_h,
        center_x + half_w,
        center_y + half_h,
    ]
    return torch.stack(corners, dim=1)


def rescale_bboxes(out_bbox, size):
    """Convert center-format boxes to corner format and scale them to an image size.

    Args:
        out_bbox: NumPy array of (center_x, center_y, width, height) boxes.
        size: ``(width, height)`` pair the coordinates are multiplied by.

    Returns:
        NumPy array of (x_min, y_min, x_max, y_max) boxes scaled by ``size``.
    """
    width, height = size
    corners = box_cxcywh_to_xyxy(out_bbox).cpu().numpy()
    # x coordinates scale with the width, y coordinates with the height.
    scale = np.array([width, height, width, height], dtype=np.float32)
    return corners * scale


def plot_one_box(x, img, color=None, label=None, line_thickness=1):
    """Draw one bounding box, and optionally a filled label, onto ``img`` in place.

    Args:
        x: Box as (x_min, y_min, x_max, y_max); values are truncated to int.
        img: Image array that is drawn on.
        color: Color sequence; a random color is picked when falsy.
        label: Optional text drawn just above the box's top-left corner.
        line_thickness: Line thickness; when falsy it is derived from the image size.
    """
    thickness = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1  # line/font thickness
    if not color:
        color = [random.randint(0, 255) for _ in range(3)]

    top_left = (int(x[0]), int(x[1]))
    bottom_right = (int(x[2]), int(x[3]))
    cv2.rectangle(img, top_left, bottom_right, color, thickness=thickness, lineType=cv2.LINE_AA)

    if not label:
        return

    font_thickness = max(thickness - 1, 1)
    font_scale = thickness / 3
    text_w, text_h = cv2.getTextSize(label, 0, fontScale=font_scale, thickness=font_thickness)[0]
    # The label background sits directly above the box's top-left corner.
    label_corner = (top_left[0] + text_w, top_left[1] - text_h - 3)
    cv2.rectangle(img, top_left, label_corner, color, -1, cv2.LINE_AA)  # filled
    text_origin = (top_left[0], top_left[1] - 2)
    cv2.putText(img, label, text_origin, 0, font_scale, [225, 255, 255],
                thickness=font_thickness, lineType=cv2.LINE_AA)


# Class-name lookup table indexed by predicted class id (91 entries, ids 0-90).
# 'N/A' marks ids that have no class name. plot_result() uses this list to
# turn the arg-max class index into a label.
# NOTE(review): the names look like the COCO detection categories - confirm
# they match the label set of the checkpoint being used.
CLASSES = [
    'N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
    'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]


def plot_result(pil_img, prob, boxes, save_name=None, imshow=False, imwrite=False):
    """Draw the detections on an image and optionally show and/or save it.

    Args:
        pil_img: RGB image (converted to a BGR array for OpenCV drawing).
        prob: Per-detection class probabilities; the arg-max selects the label.
        boxes: Per-detection boxes as (x_min, y_min, x_max, y_max).
        save_name: File name used under ``onnx/result`` when ``imwrite`` is true.
        imshow: Show the annotated image in a window and wait for a key press.
        imwrite: Save the annotated image to ``onnx/result/<save_name>``.

    Raises:
        ValueError: If ``imwrite`` is true but no ``save_name`` was given.
    """
    # Fail early with a clear message instead of trying to write to
    # 'onnx/result/None'.
    if imwrite and not save_name:
        raise ValueError("save_name is required when imwrite=True")

    cv2Image = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)

    for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes):
        cl = p.argmax()
        label_text = '{} {}%'.format(CLASSES[cl], round(p[cl] * 100, 2))
        plot_one_box((xmin, ymin, xmax, ymax), cv2Image, label=label_text)

    if imshow:
        cv2.imshow('detect', cv2Image)
        cv2.waitKey(0)

    if imwrite:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs('onnx/result', exist_ok=True)
        out_path = 'onnx/result/{}'.format(save_name)
        # cv2.imwrite reports failure through its return value; surface it
        # instead of silently dropping the result.
        if not cv2.imwrite(out_path, cv2Image):
            print('[WARN] failed to write {}'.format(out_path))


def detect_onnx(ort_session, im, prob_threshold=0.7):
    """Run the ONNX detector on one PIL image and keep confident detections.

    Args:
        ort_session: ONNX Runtime inference session of the exported model.
        im: PIL image; converted to RGB when it is in another mode.
        prob_threshold: Minimum best-class probability for a detection to be kept.

    Returns:
        Tuple ``(probabilities, boxes)``: the class probabilities of the kept
        detections (last class column dropped) and their
        (x_min, y_min, x_max, y_max) boxes scaled to ``im.size``.
    """
    # The preprocessing normalizes three channels, so grayscale, palette or
    # RGBA images must be converted first.
    if im.mode != "RGB":
        im = im.convert("RGB")

    img = transform(im).unsqueeze(0).cpu().numpy()
    # Ask the session for its input name rather than assuming the model was
    # exported with an input called "inputs".
    input_name = ort_session.get_inputs()[0].name
    scores, boxs = ort_session.run(None, {input_name: img})

    # Softmax over classes, first batch element only, last class column dropped.
    probas = torch.from_numpy(np.asarray(scores)).softmax(-1)[0, :, :-1]
    keep = (probas.max(-1).values > prob_threshold).numpy()
    probas = probas.numpy()

    bboxes_scaled = rescale_bboxes(boxs[0, keep], im.size)
    return probas[keep], bboxes_scaled


if __name__ == "__main__":
    # Script entry point: run the exported model on every image found in
    # onnx/images and save the annotated results.

    # The import at the top of the file is allowed to fail; report that
    # clearly instead of crashing with an AttributeError on None.
    if onnxruntime is None:
        raise ImportError("onnxruntime is required to run inference but could not be imported")

    onnx_path = "onnx/detr.onnx"
    ort_session = onnxruntime.InferenceSession(onnx_path)
    image_dir = "onnx/images"

    # Sorted so the processing order is the same on every run.
    for file in sorted(os.listdir(image_dir)):
        img_path = os.path.join(image_dir, file)
        if not os.path.isfile(img_path):
            continue  # skip sub-directories

        try:
            # The context manager closes the file handle; convert() returns a
            # fully loaded RGB copy that stays valid afterwards.
            with Image.open(img_path) as handle:
                im = handle.convert("RGB")
        except OSError as err:
            # Unreadable or non-image files should not abort the whole batch.
            print("[WARN] skip {}: {}".format(img_path, err))
            continue

        scores, boxes = detect_onnx(ort_session, im)
        plot_result(im, scores, boxes, save_name=file, imshow=False, imwrite=True)


预测结果:

直接用pth进行推理的可以看: DETR推理代码_athrunsunny的博客-CSDN博客

  • 11
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 9
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

athrunsunny

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值