Pytorch自适应平均池化工作原理(AdaptiveAvgPool2d无法导出ONNX,寻找替换方法)

​自适应平均池化是一种池化方法,可以在不同大小的输入中自适应地对每个位置进行平均池化。与传统的平均池化方法不同,自适应平均池化不需要指定池化核的大小,而是通过输出的大小来决定池化的大小和步幅。自适应平均池化的输出形状可以由用户指定,因此可以用于任何大小的输入。这使得自适应平均池化在处理不同大小的输入时非常有用。

摘抄自:pytorch torch.nn.AdaptiveAvgPool2d()自适应平均池化函数详解 - 优草派

ADAPTIVEAVGPOOL2D

CLASStorch.nn.AdaptiveAvgPool2d(output_size)[SOURCE]
Applies a 2D adaptive average pooling over an input signal composed of several input planes.
The output is of size H x W, for any input size. The number of output features is equal to the number of input planes.

Parameters

output_size (Union[int, None, Tuple[Optional[int], Optional[int]]]) – the target output size of the image of the form H x W. Can be a tuple (H, W) or a single H for a square image H x H. H and W can be either a int, or None which means the size will be the same as that of the input.

Shape:

  • Input: (N, C, H_{in}, W_{in})(N,C,Hin​,Win​) or (C, H_{in}, W_{in})(C,Hin​,Win​).
  • Output: (N, C, S_{0}, S_{1})(N,C,S0​,S1​) or (C, S_{0}, S_{1})(C,S0​,S1​), where S=\text{output\_size}S=output_size.

Examples

# target output size of 5x7
m = nn.AdaptiveAvgPool2d((5, 7))
input = torch.randn(1, 64, 8, 9)
output = m(input)
# target output size of 7x7 (square)
m = nn.AdaptiveAvgPool2d(7)
input = torch.randn(1, 64, 10, 9)
output = m(input)
# target output size of 10x7
m = nn.AdaptiveAvgPool2d((None, 7))
input = torch.randn(1, 64, 10, 9)
output = m(input)


以上是官网的资料,只是简单介绍了一下使用方法,没有更深层的东西
那么一起来探索一下自适应平均池化更为具体的工作原理吧

import torch
x=torch.range(1,40).reshape(1,2,4,5)
pool1=torch.nn.AdaptiveAvgPool2d(3)
pool2=torch.nn.AdaptiveAvgPool2d(6)
y1=pool1(x)
y2=pool2(x)
print(x)
print(y1)
print(y2)

x=:
tensor([[[
          [  1.,   2.,   3.,   4.,   5.],
          [  6.,   7.,   8.,   9., 10.],
          [11., 12., 13., 14., 15.],
          [16., 17., 18., 19., 20.]],

         [[21., 22., 23., 24., 25.],
          [26., 27., 28., 29., 30.],
          [31., 32., 33., 34., 35.],
          [36., 37., 38., 39., 40.]]]])

y1:
tensor([[[
          [ 4.0000,  5.5000,  7.0000],
          [ 9.0000, 10.5000, 12.0000],
          [14.0000, 15.5000, 17.0000]],

         [[24.0000, 25.5000, 27.0000],
          [29.0000, 30.5000, 32.0000],
          [34.0000, 35.5000, 37.0000]]]])

第一行:
4.0000,  5.5000,  7.0000
4.0=(1+2+6+7)/4
5.5=(8+3)/2
7.0=(4+5+9+10)/4
第一列:
4.0000,
9.0000,
14.0000,
4.0=(1+2+6+7)/4
9.0=(6+7+11+12)/4
14.0=(11+12+16+17)/4
由此可见:
在行方向5列中kernel.w=2,stride.w=2,但中间多出一列kernel.w=1
在列方向4行中kernel.h=2,stride.h=1,刚好采集3个数据
中间的一列卷积核尺寸是1*2

y2:
tensor([[[
          [ 1.0000,  1.5000,  2.5000,  3.5000,  4.5000,  5.0000],
          [ 3.5000,  4.0000,  5.0000,  6.0000,  7.0000,  7.5000],
          [ 6.0000,  6.5000,  7.5000,  8.5000,  9.5000, 10.0000],
          [11.0000, 11.5000, 12.5000, 13.5000, 14.5000, 15.0000],
          [13.5000, 14.0000, 15.0000, 16.0000, 17.0000, 17.5000],
          [16.0000, 16.5000, 17.5000, 18.5000, 19.5000, 20.0000]],

         [[21.0000, 21.5000, 22.5000, 23.5000, 24.5000, 25.0000],
          [23.5000, 24.0000, 25.0000, 26.0000, 27.0000, 27.5000],
          [26.0000, 26.5000, 27.5000, 28.5000, 29.5000, 30.0000],
          [31.0000, 31.5000, 32.5000, 33.5000, 34.5000, 35.0000],
          [33.5000, 34.0000, 35.0000, 36.0000, 37.0000, 37.5000],
          [36.0000, 36.5000, 37.5000, 38.5000, 39.5000, 40.0000]]]])
​第一行:
1.0000,  1.5000,  2.5000,  3.5000,  4.5000,  5.0000

1.0=1.0
1.5=(1+2)/2
2.5=(2+3)/2
3.5=(3+4)/2
4.5=(4+5)/2
5.0=5.0
第一列:
 1.0000,
 3.5000,
 6.0000,
11.0000,
13.5000,
16.0000,
1.0=1.0
3.5=(1+6)/2
6=6
11=11
13.5=(11+16)/2
16=16
由此可见:
向上采样大于一倍小于两倍时,优先保证两头数据和原来一致
行方向5→6只增加一个数据,中间所有值都是左右两个数据平均值
列方向4→6增加两个数据,正好两个数据中间插一个平均值

import torch
x=torch.range(1,40).reshape(1,2,4,5)
pool1=torch.nn.Upsample(scale_factor=4)#默认最邻近插值
pool2=torch.nn.AdaptiveAvgPool2d((16,20))
y1=pool1(x)
y2=pool2(x)
print(x)
print(y1)
print(y2)

y1:
tensor([[[
          [ 1.,   1.,   1.,   1.,   2.,   2.,   2.,   2.,   3.,   3.,   3.,   3.,   4.,   4.,  4.,   4.,   5.,   5.,   5.,   5.],
          [ 1.,   1.,   1.,   1.,   2.,   2.,   2.,   2.,   3.,   3.,   3.,   3.,   4.,   4.,  4.,   4.,   5.,   5.,   5.,   5.],
          [ 1.,   1.,   1.,   1.,   2.,   2.,   2.,   2.,   3.,   3.,   3.,   3.,   4.,   4.,  4.,   4.,   5.,   5.,   5.,   5.],
          [ 1.,   1.,   1.,   1.,   2.,   2.,   2.,   2.,   3.,   3.,   3.,   3.,   4.,   4.,  4.,   4.,   5.,   5.,   5.,   5.],
          [ 6.,   6.,   6.,   6.,   7.,   7.,   7.,   7.,   8.,   8.,   8.,   8.,   9.,   9.,  9.,   9., 10., 10., 10., 10.],
          [ 6.,   6.,   6.,   6.,   7.,   7.,   7.,   7.,   8.,   8.,   8.,   8.,   9.,   9.,  9.,   9., 10., 10., 10., 10.],
          [ 6.,   6.,   6.,   6.,   7.,   7.,   7.,   7.,   8.,   8.,   8.,   8.,   9.,   9.,  9.,   9., 10., 10., 10., 10.],
          [ 6.,   6.,   6.,   6.,   7.,   7.,   7.,   7.,   8.,   8.,   8.,   8.,   9.,   9.,  9.,   9., 10., 10., 10., 10.],
          [11., 11., 11., 11., 12., 12., 12., 12., 13., 13., 13., 13., 14., 14.,14., 14., 15., 15., 15., 15.],
          [11., 11., 11., 11., 12., 12., 12., 12., 13., 13., 13., 13., 14., 14.,14., 14., 15., 15., 15., 15.],
          [11., 11., 11., 11., 12., 12., 12., 12., 13., 13., 13., 13., 14., 14.,14., 14., 15., 15., 15., 15.],
          [11., 11., 11., 11., 12., 12., 12., 12., 13., 13., 13., 13., 14., 14.,14., 14., 15., 15., 15., 15.],
          [16., 16., 16., 16., 17., 17., 17., 17., 18., 18., 18., 18., 19., 19.,19., 19., 20., 20., 20., 20.],
          [16., 16., 16., 16., 17., 17., 17., 17., 18., 18., 18., 18., 19., 19.,19., 19., 20., 20., 20., 20.],
          [16., 16., 16., 16., 17., 17., 17., 17., 18., 18., 18., 18., 19., 19.,19., 19., 20., 20., 20., 20.],
          [16., 16., 16., 16., 17., 17., 17., 17., 18., 18., 18., 18., 19., 19.,19., 19., 20., 20., 20., 20.]],

         [[21., 21., 21., 21., 22., 22., 22., 22., 23., 23., 23., 23., 24., 24.,24., 24., 25., 25., 25., 25.],
          [21., 21., 21., 21., 22., 22., 22., 22., 23., 23., 23., 23., 24., 24.,24., 24., 25., 25., 25., 25.],
          [21., 21., 21., 21., 22., 22., 22., 22., 23., 23., 23., 23., 24., 24.,24., 24., 25., 25., 25., 25.],
          [21., 21., 21., 21., 22., 22., 22., 22., 23., 23., 23., 23., 24., 24.,24., 24., 25., 25., 25., 25.],
          [26., 26., 26., 26., 27., 27., 27., 27., 28., 28., 28., 28., 29., 29.,29., 29., 30., 30., 30., 30.],
          [26., 26., 26., 26., 27., 27., 27., 27., 28., 28., 28., 28., 29., 29.,29., 29., 30., 30., 30., 30.],
          [26., 26., 26., 26., 27., 27., 27., 27., 28., 28., 28., 28., 29., 29.,29., 29., 30., 30., 30., 30.],
          [26., 26., 26., 26., 27., 27., 27., 27., 28., 28., 28., 28., 29., 29.,29., 29., 30., 30., 30., 30.],
          [31., 31., 31., 31., 32., 32., 32., 32., 33., 33., 33., 33., 34., 34.,34., 34., 35., 35., 35., 35.],
          [31., 31., 31., 31., 32., 32., 32., 32., 33., 33., 33., 33., 34., 34.,34., 34., 35., 35., 35., 35.],
          [31., 31., 31., 31., 32., 32., 32., 32., 33., 33., 33., 33., 34., 34.,34., 34., 35., 35., 35., 35.],
          [31., 31., 31., 31., 32., 32., 32., 32., 33., 33., 33., 33., 34., 34.,34., 34., 35., 35., 35., 35.],
          [36., 36., 36., 36., 37., 37., 37., 37., 38., 38., 38., 38., 39., 39.,39., 39., 40., 40., 40., 40.],
          [36., 36., 36., 36., 37., 37., 37., 37., 38., 38., 38., 38., 39., 39.,39., 39., 40., 40., 40., 40.],
          [36., 36., 36., 36., 37., 37., 37., 37., 38., 38., 38., 38., 39., 39.,39., 39., 40., 40., 40., 40.],
          [36., 36., 36., 36., 37., 37., 37., 37., 38., 38., 38., 38., 39., 39.,39., 39., 40., 40., 40., 40.]]]])

y2:
tensor([[[
          [ 1.,   1.,   1.,   1.,   2.,   2.,   2.,   2.,   3.,   3.,   3.,   3.,   4.,   4.,  4.,   4.,   5.,   5.,   5.,   5.],
          [ 1.,   1.,   1.,   1.,   2.,   2.,   2.,   2.,   3.,   3.,   3.,   3.,   4.,   4.,  4.,   4.,   5.,   5.,   5.,   5.],
          [ 1.,   1.,   1.,   1.,   2.,   2.,   2.,   2.,   3.,   3.,   3.,   3.,   4.,   4.,  4.,   4.,   5.,   5.,   5.,   5.],
          [ 1.,   1.,   1.,   1.,   2.,   2.,   2.,   2.,   3.,   3.,   3.,   3.,   4.,   4.,  4.,   4.,   5.,   5.,   5.,   5.],
          [ 6.,   6.,   6.,   6.,   7.,   7.,   7.,   7.,   8.,   8.,   8.,   8.,   9.,   9.,  9.,   9., 10., 10., 10., 10.],
          [ 6.,   6.,   6.,   6.,   7.,   7.,   7.,   7.,   8.,   8.,   8.,   8.,   9.,   9.,  9.,   9., 10., 10., 10., 10.],
          [ 6.,   6.,   6.,   6.,   7.,   7.,   7.,   7.,   8.,   8.,   8.,   8.,   9.,   9.,  9.,   9., 10., 10., 10., 10.],
          [ 6.,   6.,   6.,   6.,   7.,   7.,   7.,   7.,   8.,   8.,   8.,   8.,   9.,   9.,  9.,   9., 10., 10., 10., 10.],
          [11., 11., 11., 11., 12., 12., 12., 12., 13., 13., 13., 13., 14., 14.,14., 14., 15., 15., 15., 15.],
          [11., 11., 11., 11., 12., 12., 12., 12., 13., 13., 13., 13., 14., 14.,14., 14., 15., 15., 15., 15.],
          [11., 11., 11., 11., 12., 12., 12., 12., 13., 13., 13., 13., 14., 14.,14., 14., 15., 15., 15., 15.],
          [11., 11., 11., 11., 12., 12., 12., 12., 13., 13., 13., 13., 14., 14.,14., 14., 15., 15., 15., 15.],
          [16., 16., 16., 16., 17., 17., 17., 17., 18., 18., 18., 18., 19., 19.,19., 19., 20., 20., 20., 20.],
          [16., 16., 16., 16., 17., 17., 17., 17., 18., 18., 18., 18., 19., 19.,19., 19., 20., 20., 20., 20.],
          [16., 16., 16., 16., 17., 17., 17., 17., 18., 18., 18., 18., 19., 19.,19., 19., 20., 20., 20., 20.],
          [16., 16., 16., 16., 17., 17., 17., 17., 18., 18., 18., 18., 19., 19.,19., 19., 20., 20., 20., 20.]],

         [[21., 21., 21., 21., 22., 22., 22., 22., 23., 23., 23., 23., 24., 24.,24., 24., 25., 25., 25., 25.],
          [21., 21., 21., 21., 22., 22., 22., 22., 23., 23., 23., 23., 24., 24.,24., 24., 25., 25., 25., 25.],
          [21., 21., 21., 21., 22., 22., 22., 22., 23., 23., 23., 23., 24., 24.,24., 24., 25., 25., 25., 25.],
          [21., 21., 21., 21., 22., 22., 22., 22., 23., 23., 23., 23., 24., 24.,24., 24., 25., 25., 25., 25.],
          [26., 26., 26., 26., 27., 27., 27., 27., 28., 28., 28., 28., 29., 29.,29., 29., 30., 30., 30., 30.],
          [26., 26., 26., 26., 27., 27., 27., 27., 28., 28., 28., 28., 29., 29.,29., 29., 30., 30., 30., 30.],
          [26., 26., 26., 26., 27., 27., 27., 27., 28., 28., 28., 28., 29., 29.,29., 29., 30., 30., 30., 30.],
          [26., 26., 26., 26., 27., 27., 27., 27., 28., 28., 28., 28., 29., 29.,29., 29., 30., 30., 30., 30.],
          [31., 31., 31., 31., 32., 32., 32., 32., 33., 33., 33., 33., 34., 34.,34., 34., 35., 35., 35., 35.],
          [31., 31., 31., 31., 32., 32., 32., 32., 33., 33., 33., 33., 34., 34.,34., 34., 35., 35., 35., 35.],
          [31., 31., 31., 31., 32., 32., 32., 32., 33., 33., 33., 33., 34., 34.,34., 34., 35., 35., 35., 35.],
          [31., 31., 31., 31., 32., 32., 32., 32., 33., 33., 33., 33., 34., 34.,34., 34., 35., 35., 35., 35.],
          [36., 36., 36., 36., 37., 37., 37., 37., 38., 38., 38., 38., 39., 39.,39., 39., 40., 40., 40., 40.],
          [36., 36., 36., 36., 37., 37., 37., 37., 38., 38., 38., 38., 39., 39.,39., 39., 40., 40., 40., 40.],
          [36., 36., 36., 36., 37., 37., 37., 37., 38., 38., 38., 38., 39., 39.,39., 39., 40., 40., 40., 40.],
          [36., 36., 36., 36., 37., 37., 37., 37., 38., 38., 38., 38., 39., 39.,39., 39., 40., 40., 40., 40.]]]])
由此可见:
向上采样整数倍时,和最邻近插值时一样的

根据上述数据相信你已经有了自己的答案,自适应池化不能导出的原因是因为输出特征图比输入大,ONNX没有相应的池化算子,常用的输出特征图尺寸为1*1,2*2这些或者输入和输出一样导出是没问题的,其他还没试过,缩小尺寸的池化应该都没问题

下面链接的方法把paddle换成torch,Layer换成Module可以运行,前向传播cuda利用率50%左右,反向传播几乎是100%,但是速度很慢,一个batch等很久都没结束AdaptiveAvgPool2D 不支持 onnx 导出,自定义一个类代替 AdaptiveAvgPool2D_adaptiveavgpool2d onnx-CSDN博客

  • 26
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值