自适应平均池化是一种池化方法,可以在不同大小的输入中自适应地对每个位置进行平均池化。与传统的平均池化方法不同,自适应平均池化不需要指定池化核的大小,而是通过输出的大小来决定池化的大小和步幅。自适应平均池化的输出形状可以由用户指定,因此可以用于任何大小的输入。这使得自适应平均池化在处理不同大小的输入时非常有用。
摘抄自:pytorch torch.nn.AdaptiveAvgPool2d()自适应平均池化函数详解 - 优草派
ADAPTIVEAVGPOOL2D
CLASStorch.nn.AdaptiveAvgPool2d(output_size)[SOURCE]
Applies a 2D adaptive average pooling over an input signal composed of several input planes.
The output is of size H x W, for any input size. The number of output features is equal to the number of input planes.
Parameters
output_size (Union[int, None, Tuple[Optional[int], Optional[int]]]) – the target output size of the image of the form H x W. Can be a tuple (H, W) or a single H for a square image H x H. H and W can be either a int, or None which means the size will be the same as that of the input.
Shape:
- Input: (N, C, H_{in}, W_{in})(N,C,Hin,Win) or (C, H_{in}, W_{in})(C,Hin,Win).
- Output: (N, C, S_{0}, S_{1})(N,C,S0,S1) or (C, S_{0}, S_{1})(C,S0,S1), where S=\text{output\_size}S=output_size.
Examples
# target output size of 5x7
m = nn.AdaptiveAvgPool2d((5, 7))
input = torch.randn(1, 64, 8, 9)
output = m(input)
# target output size of 7x7 (square)
m = nn.AdaptiveAvgPool2d(7)
input = torch.randn(1, 64, 10, 9)
output = m(input)
# target output size of 10x7
m = nn.AdaptiveAvgPool2d((None, 7))
input = torch.randn(1, 64, 10, 9)
output = m(input)
以上是官网的资料,只是简单介绍了一下使用方法,没有更深层的东西
那么一起来探索一下自适应平均池化更为具体的工作原理吧
import torch
x=torch.range(1,40).reshape(1,2,4,5)
pool1=torch.nn.AdaptiveAvgPool2d(3)
pool2=torch.nn.AdaptiveAvgPool2d(6)
y1=pool1(x)
y2=pool2(x)
print(x)
print(y1)
print(y2)
x=:
tensor([[[
[ 1., 2., 3., 4., 5.],
[ 6., 7., 8., 9., 10.],
[11., 12., 13., 14., 15.],
[16., 17., 18., 19., 20.]],
[[21., 22., 23., 24., 25.],
[26., 27., 28., 29., 30.],
[31., 32., 33., 34., 35.],
[36., 37., 38., 39., 40.]]]])
y1:
tensor([[[
[ 4.0000, 5.5000, 7.0000],
[ 9.0000, 10.5000, 12.0000],
[14.0000, 15.5000, 17.0000]],
[[24.0000, 25.5000, 27.0000],
[29.0000, 30.5000, 32.0000],
[34.0000, 35.5000, 37.0000]]]])
第一行:
4.0000, 5.5000, 7.0000
4.0=(1+2+6+7)/4
5.5=(8+3)/2
7.0=(4+5+9+10)/4
第一列:
4.0000,
9.0000,
14.0000,
4.0=(1+2+6+7)/4
9.0=(6+7+11+12)/4
14.0=(11+12+16+17)/4
由此可见:
在行方向5列中kernel.w=2,stride.w=2,但中间多出一列kernel.w=1
在列方向4行中kernel.h=2,stride.h=1,刚好采集3个数据
中间的一列卷积核尺寸是1*2
y2:
tensor([[[
[ 1.0000, 1.5000, 2.5000, 3.5000, 4.5000, 5.0000],
[ 3.5000, 4.0000, 5.0000, 6.0000, 7.0000, 7.5000],
[ 6.0000, 6.5000, 7.5000, 8.5000, 9.5000, 10.0000],
[11.0000, 11.5000, 12.5000, 13.5000, 14.5000, 15.0000],
[13.5000, 14.0000, 15.0000, 16.0000, 17.0000, 17.5000],
[16.0000, 16.5000, 17.5000, 18.5000, 19.5000, 20.0000]],
[[21.0000, 21.5000, 22.5000, 23.5000, 24.5000, 25.0000],
[23.5000, 24.0000, 25.0000, 26.0000, 27.0000, 27.5000],
[26.0000, 26.5000, 27.5000, 28.5000, 29.5000, 30.0000],
[31.0000, 31.5000, 32.5000, 33.5000, 34.5000, 35.0000],
[33.5000, 34.0000, 35.0000, 36.0000, 37.0000, 37.5000],
[36.0000, 36.5000, 37.5000, 38.5000, 39.5000, 40.0000]]]])
第一行:
1.0000, 1.5000, 2.5000, 3.5000, 4.5000, 5.0000
1.0=1.0
1.5=(1+2)/2
2.5=(2+3)/2
3.5=(3+4)/2
4.5=(4+5)/2
5.0=5.0
第一列:
1.0000,
3.5000,
6.0000,
11.0000,
13.5000,
16.0000,
1.0=1.0
3.5=(1+6)/2
6=6
11=11
13.5=(11+16)/2
16=16
由此可见:
向上采样大于一倍小于两倍时,优先保证两头数据和原来一致
行方向5→6只增加一个数据,中间所有值都是左右两个数据平均值
列方向4→6增加两个数据,正好两个数据中间插一个平均值
import torch
x=torch.range(1,40).reshape(1,2,4,5)
pool1=torch.nn.Upsample(scale_factor=4)#默认最邻近插值
pool2=torch.nn.AdaptiveAvgPool2d((16,20))
y1=pool1(x)
y2=pool2(x)
print(x)
print(y1)
print(y2)
y1:
tensor([[[
[ 1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3., 4., 4., 4., 4., 5., 5., 5., 5.],
[ 1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3., 4., 4., 4., 4., 5., 5., 5., 5.],
[ 1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3., 4., 4., 4., 4., 5., 5., 5., 5.],
[ 1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3., 4., 4., 4., 4., 5., 5., 5., 5.],
[ 6., 6., 6., 6., 7., 7., 7., 7., 8., 8., 8., 8., 9., 9., 9., 9., 10., 10., 10., 10.],
[ 6., 6., 6., 6., 7., 7., 7., 7., 8., 8., 8., 8., 9., 9., 9., 9., 10., 10., 10., 10.],
[ 6., 6., 6., 6., 7., 7., 7., 7., 8., 8., 8., 8., 9., 9., 9., 9., 10., 10., 10., 10.],
[ 6., 6., 6., 6., 7., 7., 7., 7., 8., 8., 8., 8., 9., 9., 9., 9., 10., 10., 10., 10.],
[11., 11., 11., 11., 12., 12., 12., 12., 13., 13., 13., 13., 14., 14.,14., 14., 15., 15., 15., 15.],
[11., 11., 11., 11., 12., 12., 12., 12., 13., 13., 13., 13., 14., 14.,14., 14., 15., 15., 15., 15.],
[11., 11., 11., 11., 12., 12., 12., 12., 13., 13., 13., 13., 14., 14.,14., 14., 15., 15., 15., 15.],
[11., 11., 11., 11., 12., 12., 12., 12., 13., 13., 13., 13., 14., 14.,14., 14., 15., 15., 15., 15.],
[16., 16., 16., 16., 17., 17., 17., 17., 18., 18., 18., 18., 19., 19.,19., 19., 20., 20., 20., 20.],
[16., 16., 16., 16., 17., 17., 17., 17., 18., 18., 18., 18., 19., 19.,19., 19., 20., 20., 20., 20.],
[16., 16., 16., 16., 17., 17., 17., 17., 18., 18., 18., 18., 19., 19.,19., 19., 20., 20., 20., 20.],
[16., 16., 16., 16., 17., 17., 17., 17., 18., 18., 18., 18., 19., 19.,19., 19., 20., 20., 20., 20.]],
[[21., 21., 21., 21., 22., 22., 22., 22., 23., 23., 23., 23., 24., 24.,24., 24., 25., 25., 25., 25.],
[21., 21., 21., 21., 22., 22., 22., 22., 23., 23., 23., 23., 24., 24.,24., 24., 25., 25., 25., 25.],
[21., 21., 21., 21., 22., 22., 22., 22., 23., 23., 23., 23., 24., 24.,24., 24., 25., 25., 25., 25.],
[21., 21., 21., 21., 22., 22., 22., 22., 23., 23., 23., 23., 24., 24.,24., 24., 25., 25., 25., 25.],
[26., 26., 26., 26., 27., 27., 27., 27., 28., 28., 28., 28., 29., 29.,29., 29., 30., 30., 30., 30.],
[26., 26., 26., 26., 27., 27., 27., 27., 28., 28., 28., 28., 29., 29.,29., 29., 30., 30., 30., 30.],
[26., 26., 26., 26., 27., 27., 27., 27., 28., 28., 28., 28., 29., 29.,29., 29., 30., 30., 30., 30.],
[26., 26., 26., 26., 27., 27., 27., 27., 28., 28., 28., 28., 29., 29.,29., 29., 30., 30., 30., 30.],
[31., 31., 31., 31., 32., 32., 32., 32., 33., 33., 33., 33., 34., 34.,34., 34., 35., 35., 35., 35.],
[31., 31., 31., 31., 32., 32., 32., 32., 33., 33., 33., 33., 34., 34.,34., 34., 35., 35., 35., 35.],
[31., 31., 31., 31., 32., 32., 32., 32., 33., 33., 33., 33., 34., 34.,34., 34., 35., 35., 35., 35.],
[31., 31., 31., 31., 32., 32., 32., 32., 33., 33., 33., 33., 34., 34.,34., 34., 35., 35., 35., 35.],
[36., 36., 36., 36., 37., 37., 37., 37., 38., 38., 38., 38., 39., 39.,39., 39., 40., 40., 40., 40.],
[36., 36., 36., 36., 37., 37., 37., 37., 38., 38., 38., 38., 39., 39.,39., 39., 40., 40., 40., 40.],
[36., 36., 36., 36., 37., 37., 37., 37., 38., 38., 38., 38., 39., 39.,39., 39., 40., 40., 40., 40.],
[36., 36., 36., 36., 37., 37., 37., 37., 38., 38., 38., 38., 39., 39.,39., 39., 40., 40., 40., 40.]]]])
y2:
tensor([[[
[ 1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3., 4., 4., 4., 4., 5., 5., 5., 5.],
[ 1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3., 4., 4., 4., 4., 5., 5., 5., 5.],
[ 1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3., 4., 4., 4., 4., 5., 5., 5., 5.],
[ 1., 1., 1., 1., 2., 2., 2., 2., 3., 3., 3., 3., 4., 4., 4., 4., 5., 5., 5., 5.],
[ 6., 6., 6., 6., 7., 7., 7., 7., 8., 8., 8., 8., 9., 9., 9., 9., 10., 10., 10., 10.],
[ 6., 6., 6., 6., 7., 7., 7., 7., 8., 8., 8., 8., 9., 9., 9., 9., 10., 10., 10., 10.],
[ 6., 6., 6., 6., 7., 7., 7., 7., 8., 8., 8., 8., 9., 9., 9., 9., 10., 10., 10., 10.],
[ 6., 6., 6., 6., 7., 7., 7., 7., 8., 8., 8., 8., 9., 9., 9., 9., 10., 10., 10., 10.],
[11., 11., 11., 11., 12., 12., 12., 12., 13., 13., 13., 13., 14., 14.,14., 14., 15., 15., 15., 15.],
[11., 11., 11., 11., 12., 12., 12., 12., 13., 13., 13., 13., 14., 14.,14., 14., 15., 15., 15., 15.],
[11., 11., 11., 11., 12., 12., 12., 12., 13., 13., 13., 13., 14., 14.,14., 14., 15., 15., 15., 15.],
[11., 11., 11., 11., 12., 12., 12., 12., 13., 13., 13., 13., 14., 14.,14., 14., 15., 15., 15., 15.],
[16., 16., 16., 16., 17., 17., 17., 17., 18., 18., 18., 18., 19., 19.,19., 19., 20., 20., 20., 20.],
[16., 16., 16., 16., 17., 17., 17., 17., 18., 18., 18., 18., 19., 19.,19., 19., 20., 20., 20., 20.],
[16., 16., 16., 16., 17., 17., 17., 17., 18., 18., 18., 18., 19., 19.,19., 19., 20., 20., 20., 20.],
[16., 16., 16., 16., 17., 17., 17., 17., 18., 18., 18., 18., 19., 19.,19., 19., 20., 20., 20., 20.]],
[[21., 21., 21., 21., 22., 22., 22., 22., 23., 23., 23., 23., 24., 24.,24., 24., 25., 25., 25., 25.],
[21., 21., 21., 21., 22., 22., 22., 22., 23., 23., 23., 23., 24., 24.,24., 24., 25., 25., 25., 25.],
[21., 21., 21., 21., 22., 22., 22., 22., 23., 23., 23., 23., 24., 24.,24., 24., 25., 25., 25., 25.],
[21., 21., 21., 21., 22., 22., 22., 22., 23., 23., 23., 23., 24., 24.,24., 24., 25., 25., 25., 25.],
[26., 26., 26., 26., 27., 27., 27., 27., 28., 28., 28., 28., 29., 29.,29., 29., 30., 30., 30., 30.],
[26., 26., 26., 26., 27., 27., 27., 27., 28., 28., 28., 28., 29., 29.,29., 29., 30., 30., 30., 30.],
[26., 26., 26., 26., 27., 27., 27., 27., 28., 28., 28., 28., 29., 29.,29., 29., 30., 30., 30., 30.],
[26., 26., 26., 26., 27., 27., 27., 27., 28., 28., 28., 28., 29., 29.,29., 29., 30., 30., 30., 30.],
[31., 31., 31., 31., 32., 32., 32., 32., 33., 33., 33., 33., 34., 34.,34., 34., 35., 35., 35., 35.],
[31., 31., 31., 31., 32., 32., 32., 32., 33., 33., 33., 33., 34., 34.,34., 34., 35., 35., 35., 35.],
[31., 31., 31., 31., 32., 32., 32., 32., 33., 33., 33., 33., 34., 34.,34., 34., 35., 35., 35., 35.],
[31., 31., 31., 31., 32., 32., 32., 32., 33., 33., 33., 33., 34., 34.,34., 34., 35., 35., 35., 35.],
[36., 36., 36., 36., 37., 37., 37., 37., 38., 38., 38., 38., 39., 39.,39., 39., 40., 40., 40., 40.],
[36., 36., 36., 36., 37., 37., 37., 37., 38., 38., 38., 38., 39., 39.,39., 39., 40., 40., 40., 40.],
[36., 36., 36., 36., 37., 37., 37., 37., 38., 38., 38., 38., 39., 39.,39., 39., 40., 40., 40., 40.],
[36., 36., 36., 36., 37., 37., 37., 37., 38., 38., 38., 38., 39., 39.,39., 39., 40., 40., 40., 40.]]]])
由此可见:
向上采样整数倍时,和最邻近插值时一样的
根据上述数据相信你已经有了自己的答案,自适应池化不能导出的原因是因为输出特征图比输入大,ONNX没有相应的池化算子,常用的输出特征图尺寸为1*1,2*2这些或者输入和输出一样导出是没问题的,其他还没试过,缩小尺寸的池化应该都没问题
下面链接的方法把paddle换成torch,Layer换成Module可以运行,前向传播cuda利用率50%左右,反向传播几乎是100%,但是速度很慢,一个batch等很久都没结束AdaptiveAvgPool2D 不支持 onnx 导出,自定义一个类代替 AdaptiveAvgPool2D_adaptiveavgpool2d onnx-CSDN博客