一、SPP-net原理理解
针对卷积神经网络重复运算问题,2015年微软研究院的何恺明等提出一种SPP-Net算法,通过在卷积层和全连接层之间加入空间金字塔池化结构(Spatial Pyramid Pooling)代替R-CNN算法在输入卷积神经网络前对各个候选区域进行剪裁、缩放操作使其图像子块尺寸一致的做法。
利用空间金字塔池化结构有效避免了R-CNN算法对图像区域剪裁、缩放操作导致的图像物体剪裁不全以及形状扭曲等问题,更重要的是解决了卷积神经网络对图像重复特征提取的问题,大大提高了产生候选框的速度,且节省了计算成本。
算法流程:
1)区域提名:用Selective Search从原图中生成2000个左右的候选窗口,这一步和R-CNN一样;
2)区域大小缩放:SPP-net不再做区域大小归一化,而是缩放到min(w, h)=s,即统一长宽的最短边长度,s选自{480,576,688,864,1200}中的一个,选择的标准是使得缩放后的候选框大小与224×224最接近;
3)特征提取:利用SPP-net网络结构提取特征:把整张待检测的图片,输入CNN中进行一次性特征提取,得到feature maps,然后在feature maps中找到各个候选框的区域,再对各个候选框采用金字塔空间池化,提取出固定长度的特征向量;
4)分类与回归:类似R-CNN,利用SVM基于上面的特征训练分类器模型,用边框回归来微调候选框的位置。
创新点:
1)利用空间金字塔池化结构;
2)对整张图片只进行了一次特征提取,加快运算速度。
金字塔池化结构
金字塔池化就是把原来的特征图分别分成4x4=16块,2x2=4块,1x1=1块(不变),总共21块,取每块的最大值作为代表,即每张特征图就有21维的参数,总共卷积出来256个特征图,则送入全连接层的维度就是21256。这样就解决了输入数据大小任意的问题。
这样就把一张任意大小的图片转换成了一个固定大小的21维特征(当然也可以设计其它维数的输出,增加金字塔的层数,或者改变划分网格的大小)。上面的三种不同刻度的划分,每一种刻度我们称之为:金字塔的一层,每一个图片块大小我们称之为:windows size。
SPP-net公式:
设输入数据大小是(c, hin ,win),分别表示通道数,高度,宽度;
池化数量:(n,n);
则:
Kh表示核的高度
Sh表示高度方向的步长
Ph表示高度方向的填充数量,需要乘以2
注意核和步长的计算公式都使用的是ceil(),即向上取整,而padding使用的是floor(),即向下取整。
小结:
SPP-net解决了R-CNN区域提名时crop/warp带来的偏差问题,提出了SPP层,使得输入的候选框可大可小,速度也有了一定的提升。但其他方面依然和R-CNN一样,因而依然存在不少问题,如它的训练要经过多个阶段,特征也要存在磁盘中,这就有了后面的Fast R-CNN。
二、代码实现(Python)
采用PyTorch深度学习框架,构建了一个SPP层,代码如下:
#coding=utf-8
import math
import torch
import torch.nn.functional as F
# 构建SPP层(空间金字塔池化层)
class SPPLayer(torch.nn.Module):
def __init__(self, num_levels, pool_type='max_pool'):
super(SPPLayer, self).__init__()
self.num_levels = num_levels
self.pool_type = pool_type
def forward(self, x):
num, c, h, w = x.size() # num:样本数量 c:通道数 h:高 w:宽
for i in range(self.num_levels):
level = i+1
kernel_size = (math.ceil(h / level), math.ceil(w / level))
stride = (math.ceil(h / level), math.ceil(w / level))
pooling = (math.floor((kernel_size[0]*level-h+1)/2), math.floor((kernel_size[1]*level-w+1)/2))
# 选择池化方式
if self.pool_type == 'max_pool':
tensor = F.max_pool2d(x, kernel_size=kernel_size, stride=stride, padding=pooling).view(num, -1)
else:
tensor = F.avg_pool2d(x, kernel_size=kernel_size, stride=stride, padding=pooling).view(num, -1)
# 展开、拼接
if (i == 0):
x_flatten = tensor.view(num, -1)
else:
x_flatten = torch.cat((x_flatten, tensor.view(num, -1)), 1)
return x_flatten
因为本人目前是小白阶段学习,主要参考了论文原文及一些网络博客,在此表示感谢。如果有写的不对的地方,欢迎批评指正。