空间金字塔池化(Spatial Pyramid Pooling, SPP)
在SPPnet和Fast-RCNN中都用到了空间金字塔池化(Spatial Pyramid Pooling, SPP)来提高object detection的效率。SPP本质的目的是为了使得CNN可以接受任意尺寸的输入图片,从而避免了图像预处理中要将图片resize到统一尺寸这个限制。
SPPnet论文中给出了如下的图示:
给定一张任意尺寸的输入图片,首先经过一个网络,得到某卷积层的输出(即图中Conv5的输出);之后经过三个不同的池化层,分别得到 16 × 25616 × 25616 × 256 16×25616×256 16\times 256 16×25616×25616×2561×256维的向量,将这些向量拼接起来,即得到固定长度的特征向量(fixed-length representation,这里长度为5376)。无论输入的图片是任何尺寸的,我们都可以得到长度为5376的特征向量。
卷积层实际上并不受限于输入图像的尺寸,无论给定什么尺寸的图,卷积都是可以进行的。然而后面的全连接层就不行了,必须接受固定尺寸的输入。所以通过上述的金字塔池化,就可以解决全连接层的输入问题。从而使得网络可以接受任意尺寸的输入图片。
单看论文还是不好理解,到底是如何操作的。我们从代码的角度来看下是如何实现的,最后再证明下这样确实是可行的。
import torch.nn as nn
import torch
def spatial_pyramid_pool(previous_conv, num_sample, previous_conv_size, out_pool_size):
‘’’
previous_conv: a tensor vector of previous convolution layer
num_sample: an int number of image in the batch
previous_conv_size: an int vector [height, width] of the matrix features size of previous convolution layer
out_pool_size: a int vector of expected output size of max pooling layer
returns: a tensor vector with shape [1 x n] is the concentration of multi-level pooling
'''
print(previous_conv.size())
for i in range(len(out_pool_size)):
h_wid = int(math.ceil(previous_conv_size[0] / out_pool_size[i]))
w_wid = int(math.ceil(previous_conv_size[1] / out_pool_size[i]))
h_pad = int((h_wid*out_pool_size[i] - previous_conv_size[0] + 1)/2)
w_pad = int((w_wid*out_pool_size[i] - previous_conv_size[1] + 1)/2)
print (h_wid, w_wid,h_pad,w_pad)
maxpool = nn.MaxPool2d((h_wid, w_wid), stride=(h_wid, w_wid), padding=(h_pad, w_pad))
x = maxpool(previous_conv)
if(i == 0):
spp = x.view(num_sample,-1)
else:
spp = torch.cat((spp,x.view(num_sample,-1)), 1)
return spp
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
假如这里输入的feature map即previous_conv的channel为512,out_pool_size=[4,2,1]。则经过这三个池化层,我们分别得到 4 × 4 × 5124 × 4 × 5124 × 4 × 512 4×4×5124×4×512 4\times 4\times 512 4×4×5124×4×5124×4×512n值也是一致的,因此输出尺寸是一致的。因此可以适应于不同尺寸的图片输入。
</div><div><div></div></div>
<link href="https://csdnimg.cn/release/phoenix/mdeditor/markdown_views-60ecaf1f42.css" rel="stylesheet">
</div>