pytorch模型导出成ONNX格式:支持多参数与动态输入

pytorch格式的模型在部署之前一般需要做格式转换。本文介绍了如何将pytorch格式的模型导出到ONNX格式的模型。ONNX(Open Neural Network Exchange)格式是一种常用的开源神经网络格式,被较多推理引擎支持,比如:ONNXRuntime, Intel OpenVINO, TensorRT等。

1. 网络结构定义

我们以一个Image Super Resolution的模型为例。首先,需要知道模型的网络定义SuperResolutionNet,并创建模型对象torch_model

# Super Resolution model definition in PyTorch
import torch.nn as nn
import torch.nn.init as init

class SuperResolutionNet(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)
        init.zeros_(self.conv4.bias)

# Create the super-resolution model by using the above model definition.
torch_model = SuperResolutionNet(upscale_factor=3)

2. 加载模型文件

Pytorch模型的参数信息存储在state_dict中。 state_dict是一个Python字典结构的对象,里面存储了神经网络中每层对应的参数张量。将每层的参数结构以及最后一层的bias打印出来:

def print_state_dict(state_dict):    
    print(len(state_dict))
    for layer in state_dict:
        print(layer, '\t', state_dict[layer].shape)
    print(state_dict['conv4.bias'])
print_state_dict(model.state_dict())

输出:

8
conv1.weight 	 torch.Size([64, 1, 5, 5])
conv1.bias 	 torch.Size([64])
conv2.weight 	 torch.Size([64, 64, 3, 3])
conv2.bias 	 torch.Size([64])
conv3.weight 	 torch.Size([32, 64, 3, 3])
conv3.bias 	 torch.Size([32])
conv4.weight 	 torch.Size([9, 32, 3, 3])
conv4.bias 	 torch.Size([9])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0.])

因为之前将第四层的bias初始化为0,所以输出是全零。然后调用load_state_dict加载模型文件,可以看到加载之后参数的变化。eval将模型设置为推理状态。

# Load pretrained model weights
model_url = 'https://s3.amazonaws.com/pytorch/test_data/export/superres_epoch100-44c6958e.pth'

model.load_state_dict(model_zoo.load_url(model_url))

print_state_dict(model.state_dict())
# set the model to inference mode
model.eval()

输出

8
conv1.weight 	 torch.Size([64, 1, 5, 5])
conv1.bias 	 torch.Size([64])
conv2.weight 	 torch.Size([64, 64, 3, 3])
conv2.bias 	 torch.Size([64])
conv3.weight 	 torch.Size([32, 64, 3, 3])
conv3.bias 	 torch.Size([32])
conv4.weight 	 torch.Size([9, 32, 3, 3])
conv4.bias 	 torch.Size([9])
tensor([-0.0151, -0.0191, -0.0362, -0.0224,  0.0548,  0.0113,  0.0529,  0.0258,
        -0.0180])
SuperResolutionNet(
  (relu): ReLU()
  (conv1): Conv2d(1, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(32, 9, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (pixel_shuffle): PixelShuffle(upscale_factor=3)
)

3. 输出成ONNX格式

在调用torch.onnx.export之前,需要先创建输入数据。因为模型的导出实际上是执行了一次推理过程。在执行的过程中记录使用到的操作。输入数据可以是随机的:

# Input to the model
x = torch.randn(1, 1, 224, 224, requires_grad=True)
# Export the model
torch.onnx.export(model,               # model being run
                  x,                         # model input 
                  "D:\\super_resolution.onnx",   # where to save the model (can be a file or file-like object)                  
                  opset_version=11,          # the ONNX version to export the model to                  
                  input_names = ['input'],   # the model's input names
                  output_names = ['output']  # the model's output names
                  )

export的第一个参数是模型对象,第二个参数是输入数据,第三个参数是输出的模型文件名,这三个参数是必须指定的。还有一些常用的可选参数:

  • opset_version, 指定的操作版本,一般越高的版本会支持更多的操作。如果遇到某个操作不支持,可以将版本号设置的高一点试试。
  • input_names, 输入参数名。如果不指定,会使用默认名字。
  • output_names, 输出参数名。如果不知道,会使用默认名字。

输出成功后,可以使用Netron查看网络结构。Netron是一个开源的神经网络模型可视化工具,可以使用在线网页版的https://netron.app/,或者下载安装桌面版的https://github.com/lutzroeder/netron。打开导出的模型,结构如下:
在这里插入图片描述

4. 导出动态输入模型

可以看到上面导出的模型输入是固定的1 x 1 x 224 x 224输出是固定的1 x 1 x 672 x 672.实际应用的时候输入图片的尺寸是不固定的,而且可能一次输入多种图片一起处理。我们可以通过指定dynamic_axes参数来导出动态输入的模型。dynamic_axes的参数是一个字典类型,字典的key就是输入或者输出的名字,对应key的value可以是一个字典或者列表,指定了输入或者输出的index以及对应的名字。比如想要让输入的index为0的维度表示动态的batch_size那么就指定{0: 'batch_size'}。同样的方法可以指定宽高所在的维度输出成动态的。

input_name = 'input'
output_name = 'output'
torch.onnx.export(model,               # model being run
                  x,                         # model input 
                  "D:\\super_resolution_2.onnx",   # where to save the model (can be a file or file-like object)                  
                  opset_version=11,          # the ONNX version to export the model to                  
                  input_names = [input_name],   # the model's input names
                  output_names = [output_name],  # the model's output names
                  dynamic_axes= {
                        input_name: {0: 'batch_size', 2 : 'in_width', 3: 'int_height'},
                        output_name: {0: 'batch_size', 2: 'out_width', 3:'out_height'}}
                  )

输出的模型使用Netron打开,结构如下:
在这里插入图片描述
查看输入输出信息可以看到,输入的维度变成:[batch_size,1,in_width,int_height],输出的维度变成:[batch_size,1,out_width,out_height]。表示这个模型可以接收动态的批次大小和宽高尺寸。
在这里插入图片描述

5. 多参数输入

5.1 多参数输入模型的导出

有时候可能会遇到比较复杂的模型,推理时需要输入多个参数的情况。我们可以通过将参数列表包在一个list中来输出ONNX模型。我们先将模型的forward方法修改一下,增加一个输入参数scale

class SuperResolutionNet2(nn.Module):
    def __init__(self, upscale_factor, inplace=False):
        super(SuperResolutionNet2, self).__init__()

        self.relu = nn.ReLU(inplace=inplace)
        self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
        self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
        self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
        self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        self._initialize_weights()

    def forward(self, x, scale):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.pixel_shuffle(self.conv4(x))
        return x

    def _initialize_weights(self):
        init.orthogonal_(self.conv1.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv2.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv3.weight, init.calculate_gain('relu'))
        init.orthogonal_(self.conv4.weight)
        init.zeros_(self.conv4.bias)  

# Create the super-resolution model by using the above model definition.
model2 = SuperResolutionNet2(upscale_factor=3)

调用export输出到ONNX:

input_name = 'input'
output_name = 'output'
torch.onnx.export(model2,               
                  (x, 2),                         
                  "D:\\super_resolution_3.onnx",   
                  opset_version=11,          
                  input_names = [input_name],  
                  output_names = [output_name],
                  dynamic_axes= {
                        input_name: {0: 'batch_size', 2 : 'in_width', 3: 'int_height'},
                        output_name: {0: 'batch_size', 2: 'out_width', 3:'out_height'}}
                  )

5.2 易错点

由于export函数的机制,会把模型输入的参数自动转换成tensor类型,比如上面的scale参数,虽然传入的时候是int32类型,但是export在执行时会调用到forward函数,此时scale已经变成一个tensor类型。我们可以做个测试,打印一下scale的类型来验证:

def forward(self, x, scale):
    print(scale)
    x = self.relu(self.conv1(x))
    x = self.relu(self.conv2(x))
    x = self.relu(self.conv3(x))
    x = self.pixel_shuffle(self.conv4(x))
    return x

重新运行export后输出:

tensor(2)

这种机制带来的影响是,在使用scale参数时可能需要做一个转换,比如转换成float类型。否则某些函数的调用会失败。以插值函数为例做个测试,将forward修改一下:

def forward(self, x, scale):
    print(scale)        
    y = F.interpolate(x, scale_factor= 1./scale, mode="bilinear")
    x = self.relu(self.conv1(x))
    x = self.relu(self.conv2(x))
    x = self.relu(self.conv3(x))
    x = self.pixel_shuffle(self.conv4(x))
    return x

这个时候运行export会报错,因为插值函数的scale_factor参数不能是一个tensor类型。修改后的正确版本:

def forward(self, x, scale):
    print(scale)        
    y = F.interpolate(x, scale_factor= 1./float(scale), mode="bilinear")
    x = self.relu(self.conv1(x))
    x = self.relu(self.conv2(x))
    x = self.relu(self.conv3(x))
    x = self.pixel_shuffle(self.conv4(x))
    return x

6. 完整代码

在这里 https://github.com/jb2020-super/pytorch-utils/blob/main/to_onnx_ex.ipynb

7. 参考

https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html
https://pytorch.org/docs/stable/onnx.html?highlight=export#torch.onnx.export
https://onnxruntime.ai/docs/get-started/with-python.html

  • 18
    点赞
  • 83
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

superbin

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值