unet++ pytorch模型转换为onnx模型并实际测试

5 篇文章 0 订阅
2 篇文章 0 订阅

书接上回,上次在安装好openvino环境之后,以及自己在了解完其相关的处理流程之后,现在将自己的模型转换为onnx格式以便后续转换为openvino的中间件。
直接上代码:

import os
import cv2
import onnxruntime
import torch
from albumentations import Compose
from albumentations.augmentations import transforms
from torch.utils.data import DataLoader

import L1_archs_cut
from dataset import test_Dataset


def pth_2onnx():
    """
    pytorch 模型转换为onnx模型
    :return:
    """
    torch_model = torch.load('./model/model.pth')

    model = L1_archs_cut.NestedUNet(num_classes=1, input_channels=3, deep_supervision=True)
    model.load_state_dict(torch_model)
    batch_size = 1  # 批处理大小
    input_shape = (3, 1920, 1088)  # 输入数据

    # set the model to inference mode
    model.eval()
    print(model)
    x = torch.randn(batch_size, *input_shape)  # 生成张量
    export_onnx_file = "model.onnx"  # 目的ONNX文件名
    torch.onnx.export(model,
                      x,
                      export_onnx_file,
                      # 注意这个地方版本选择为11
                      opset_version=11,
                      do_constant_folding=True,  # 是否执行常量折叠优化
                      input_names=["input"],  # 输入名
                      output_names=["output"],  # 输出名
                      dynamic_axes={"input": {0: "batch_size"},  # 批处理变量
                                    "output": {0: "batch_size"}})


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


def image_test():
    """
    实际测试onnx模型效果
    :return:
    """
    onnx_path = './model/model.onnx'
    image_path = './data/ray_test/'

    test_transform = Compose([
        transforms.Normalize(),
    ])

    test_dataset = test_Dataset(
        img_ids='0',
        img_dir=image_path,
        num_classes=1,
        transform=test_transform
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        drop_last=False
    )
    # print(test_loader)
    for input, meta in test_loader:
        ort_session = onnxruntime.InferenceSession(onnx_path)
        # print('input', input.shape)
        # print(input.shape)
        ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(input)}
        # print('ort_inputs', len(ort_inputs))
        ort_outs = ort_session.run(None, ort_inputs)
        # print('ort_outs', type(ort_outs))
        img_out = ort_outs[0]
        img_out = torch.from_numpy(img_out)
        # print('1', img_out)
        img_out = torch.sigmoid(img_out).cpu().numpy()

        # print('img_out', img_out.shape)
        img_out = img_out.transpose(0, 1, 3, 2)
        num_classes = 1
        for i in range(len(img_out)):
            cv2.imwrite(os.path.join('./', meta['img_id'][i].split('.')[0] + '.png'),
                        (img_out[i, num_classes - 1] * 255).astype('uint8'))


if __name__ == '__main__':
    # pth_2onnx()
    image_test()

数据加载的模块:

import os

import cv2
import torch.utils.data
class test_Dataset(torch.utils.data.Dataset):
    def __init__(self, img_ids, img_dir, num_classes, transform=None):
        self.img_ids = img_ids
        self.img_dir = img_dir
        self.num_classes = num_classes
        self.transform = transform

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):

        img_id = self.img_ids[idx]

        img = cv2.imread(os.path.join(self.img_dir, img_id + '.png'))
        if self.transform is not None:
            augmented = self.transform(image=img)
            img = augmented['image']

        img = img.astype('float32') / 255

        img = img.transpose(2, 1, 0)

        return img, {'img_id': img_id}

以上代码是经过自己的测试转换为onnx格式的unet++模型和pth格式的效果是一样的。

注:
若在转换为中间件的过程中若出现一下错误:

[ ERROR ]  Exception occurred during running replacer "REPLACEMENT_ID" (<class 'extensions.middle.DecomposeBias.DecomposeBias'>): After partial shape inference were found shape collision for node Conv_0 (old shape: [   0   32 1920 1088], new shape: [  -1   32 1920 1088])

可将onnx转化的

  do_constant_folding=True,  # 是否执行常量折叠优化
              input_names=["input"],  # 输入名
              output_names=["output"],  # 输出名
              dynamic_axes={"input": {0: "batch_size"},  # 批处理变量
                            "output": {0: "batch_size"}}

这部分代码删除掉
江湖不是打打杀杀,是人情世故,若你有幸看到这边文章点个关注可否

  • 6
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值