torch onnx dynamic input and dynamic output

The code is as follows:

# Some standard imports
import io
import numpy as np

import torch.utils.model_zoo as model_zoo
import torch.onnx
# Super Resolution model definition in PyTorch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import onnxruntime


class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
        # self.upsample = nn.Upsample(scale_factor=upscale_factor, mode='nearest')
        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))  # torch.Size([1, 32, 224, 224])
        x = self.conv4(x)  # torch.Size([1, 9, 224, 224]) 224*3=672
        x = self.pixel_shuffle(x)  # torch.Size([1, 1, 672, 672])
        # x = self.upsample(x)
        x = F.interpolate(x, scale_factor=0.5)
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)


# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)

# Initialize model with the pretrained weights
map_location = lambda storage, loc: storage
if torch.cuda.is_available():
    map_location = None
pth_path = r"superres_epoch100-44c6958e.pth"
checkpoint = torch.load(pth_path, map_location=map_location)
torch_model.load_state_dict(checkpoint)

# set the model to inference mode
torch_model.eval()

# Input to the model
batch_size = 1  # just a random number
channel = 1
h_size = 224
w_size = 224
x = torch.randn(batch_size, channel, h_size, w_size, requires_grad=True)
torch_out = torch_model(x)

dynamic_axes = {'input': {0: 'batch_size', 1: 'channel', 2: 'height', 3: 'width'},  # variable length axes
                'output': {0: 'batch_size', 1: 'channel', 2: 'height', 3: 'width'}}
# Export the model
torch.onnx.export(torch_model,  # model being run
                  x,  # model input (or a tuple for multiple inputs)
                  "super_resolution.onnx",  # where to save the model (can be a file or file-like object)
                  export_params=True,  # store the trained parameter weights inside the model file
                  opset_version=11,  # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],  # the model's input names
                  output_names=['output'],  # the model's output names
                  dynamic_axes=dynamic_axes)

import onnx

onnx_model = onnx.load("super_resolution.onnx")
onnx.checker.check_model(onnx_model)


def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()


ort_session = onnxruntime.InferenceSession("super_resolution.onnx")
ort_x = torch.randn(2, 1, 321, 321, requires_grad=True)
# compute ONNX Runtime output prediction
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(ort_x)}
ort_outs = ort_session.run(None, ort_inputs)

torch_out = torch_model(ort_x)
# compare ONNX Runtime and PyTorch results
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

Inspecting the model loaded by onnxruntime:
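The original screenshot is not reproduced here, but the same information can be printed directly from the onnxruntime session. A minimal sketch, reusing the ort_session created above, that lists each input's and output's name, shape, and element type:

# Print the metadata onnxruntime reports for the loaded model
for inp in ort_session.get_inputs():
    print("input :", inp.name, inp.shape, inp.type)
for out in ort_session.get_outputs():
    print("output:", out.name, out.shape, out.type)
# With the dynamic_axes above, both shapes show the symbolic names
# ['batch_size', 'channel', 'height', 'width'] with type tensor(float).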

References:

https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
https://github.com/shaoboc-lmf/mrcnn-onnx-export/blob/0e8d029b6b4173bd0e1c0685ba5b395ea9d48eb5/test_debug_onnx_rpn_head_concat_anchor_gen.py