转载自:https://www.cnblogs.com/marsggbo/p/8572846.html
一、为什么需要SPP
首先需要知道为什么会需要SPP。
我们都知道卷积神经网络(CNN)由卷积层和全连接层组成,其中卷积层对于输入数据的大小并没有要求,唯一对数据大小有要求的则是第一个全连接层,因此基本上所有的CNN都要求输入数据固定大小,例如著名的VGG模型则要求输入数据大小是 (224*224) 。
固定输入数据大小有两个问题:
1.很多场景所得到数据并不是固定大小的,例如街景文字基本上其高宽比是不固定的,如下图示红色框出的文字。
2.可能你会说可以对图片进行切割,但是切割的话很可能会丢失到重要信息。
综上,SPP的提出就是为了解决CNN输入图像大小必须固定的问题,从而可以使得输入图像高宽比和大小任意。
二、SPP原理
更加具体的原理可查阅原论文:Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition
上图是原文中给出的示意图,需要从下往上看:
- 首先是输入层(input image),其大小可以是任意的
- 进行卷积运算,到最后一个卷积层(图中是conv5conv5。
四、代码实现(Python)
这里我使用的是PyTorch深度学习框架,构建了一个SPP层,代码如下:
#coding=utf-8
import math
import torch
import torch.nn.functional as F
# 构建SPP层(空间金字塔池化层)
class SPPLayer(torch.nn.Module):
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">__init__</span><span class="hljs-params">(self, num_levels, pool_type=<span class="hljs-string">'max_pool'</span>)</span>:</span>
super(SPPLayer, self).__init__()
self.num_levels = num_levels
self.pool_type = pool_type
<span class="hljs-function"><span class="hljs-keyword">def</span> <span class="hljs-title">forward</span><span class="hljs-params">(self, x)</span>:</span>
num, c, h, w = x.size() <span class="hljs-comment"># num:样本数量 c:通道数 h:高 w:宽</span>
<span class="hljs-keyword">for</span> i <span class="hljs-keyword">in</span> range(self.num_levels):
level = i+<span class="hljs-number">1</span>
kernel_size = (math.ceil(h / level), math.ceil(w / level))
stride = (math.ceil(h / level), math.ceil(w / level))
pooling = (math.floor((kernel_size[<span class="hljs-number">0</span>]*level-h+<span class="hljs-number">1</span>)/<span class="hljs-number">2</span>), math.floor((kernel_size[<span class="hljs-number">1</span>]*level-w+<span class="hljs-number">1</span>)/<span class="hljs-number">2</span>))
<span class="hljs-comment"># 选择池化方式 </span>
<span class="hljs-keyword">if</span> self.pool_type == <span class="hljs-string">'max_pool'</span>:
tensor = F.max_pool2d(x, kernel_size=kernel_size, stride=stride, padding=pooling).view(num, <span class="hljs-number">-1</span>)
<span class="hljs-keyword">else</span>:
tensor = F.avg_pool2d(x, kernel_size=kernel_size, stride=stride, padding=pooling).view(num, <span class="hljs-number">-1</span>)
<span class="hljs-comment"># 展开、拼接</span>
<span class="hljs-keyword">if</span> (i == <span class="hljs-number">0</span>):
x_flatten = tensor.view(num, <span class="hljs-number">-1</span>)
<span class="hljs-keyword">else</span>:
x_flatten = torch.cat((x_flatten, tensor.view(num, <span class="hljs-number">-1</span>)), <span class="hljs-number">1</span>)
<span class="hljs-keyword">return</span> x_flatten
上述代码参考: sppnet-pytorch
为防止原作者将代码删除,我已经Fork了,也可以通过如下地址访问代码:
marsggbo/sppnet-pytorch