Paper: Strip Pooling
Code: github
1 Introduction
Based on the assumption that spatial pooling can effectively capture long-range contextual information, the paper makes three main contributions:
- a new spatial pooling structure built on asymmetric $1\times N$ and $N\times 1$ pooling windows;
- a new Strip Pooling Module and a Mixed Pooling Module, both built on strip pooling;
- SPNet, a network assembled from strip pooling and mixed pooling.
2 Motivation
Since FCNs and U-Net, work on extracting spatial context for semantic segmentation has split into two directions. One introduces non-local self-attention to enhance the features extracted by the backbone, at the cost of heavy memory consumption; the other widens the receptive field with modules such as dilated convolutions and PPM to capture spatial information. The authors argue that square-windowed structures like PPM limit the flexibility of capturing the anisotropic context that is widespread in real scenes.
The idea here is therefore to extract spatial features along each dimension separately.
3 Network Architecture
3.1 Strip Pooling
Standard spatial average pooling:
$$y_{i_o,j_o}=\frac{1}{h\times w}\sum_{0\le i<h}\sum_{0\le j<w}{x_{i_o\times h + i,\ j_o\times w + j}}$$
where:
- $x\in \mathcal{R}^{H\times W}$ is the input;
- $h\times w$ is the pooling window size;
- $H_o=\frac{H}{h},\ W_o=\frac{W}{w}$ is the output size;
- $0\le i_o < H_o,\ 0\le j_o < W_o$.
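As a quick sanity check of the formula (my own example, not from the paper), the window average can be verified against PyTorch's `nn.AvgPool2d`:

```python
import torch
import torch.nn as nn

# 4x4 input, 2x2 pooling window: each output value is the mean of its window
x = torch.arange(16.0).reshape(1, 1, 4, 4)  # H = W = 4
h = w = 2                                    # pooling window size
pooled = nn.AvgPool2d(kernel_size=(h, w))(x)  # output is 1 x 1 x 2 x 2

# manual computation for output position (0, 0): mean of x[0:2, 0:2]
manual = x[0, 0, 0:2, 0:2].mean()
assert torch.isclose(pooled[0, 0, 0, 0], manual)
```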
Strip pooling instead uses asymmetric (strip-shaped) windows, so its outputs are:
- horizontal: $y_i^h=\frac{1}{W}\sum_{0\le j <W}x_{i,j}$
- vertical: $y_j^v=\frac{1}{H}\sum_{0\le i <H}x_{i,j}$
where:
- $y^h\in \mathcal{R}^H$;
- $y^v\in \mathcal{R}^W$.
Looking at the formulas, the "long-range" property the authors refer to essentially comes from pooling all values along one entire dimension, which yields a global receptive field along that dimension.
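In PyTorch, these two strip averages can be expressed with `AdaptiveAvgPool2d` by setting the output size to 1 along the pooled axis; a minimal sketch (my own, following the paper's notation where $y^h$ has length $H$):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 3, 5, 7)  # N x C x H x W

y_h = nn.AdaptiveAvgPool2d((None, 1))(x)  # horizontal: average over W -> N x C x H x 1
y_v = nn.AdaptiveAvgPool2d((1, None))(x)  # vertical:   average over H -> N x C x 1 x W

# equivalent to an explicit mean along a single spatial dimension
assert torch.allclose(y_h, x.mean(dim=3, keepdim=True))
assert torch.allclose(y_v, x.mean(dim=2, keepdim=True))
```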
3.2 Strip Pooling Module
The Strip Pooling Module has two branches: one applies horizontal strip pooling to obtain $y^h\in \mathcal{R}^{C\times H}$, the other applies vertical strip pooling to obtain $y^v\in \mathcal{R}^{C\times W}$. The two results are then expanded and added pairwise to produce $y\in \mathcal{R}^{C\times H\times W}$:
$$y_{c,i,j}=y^h_{c,i}+y^v_{c,j}$$
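This pairwise addition is exactly a broadcast of a $C\times H\times 1$ tensor against a $C\times 1\times W$ tensor (an illustrative example of my own):

```python
import torch

C, H, W = 2, 3, 4
y_h = torch.randn(C, H, 1)  # horizontal strip result, expanded with a trailing axis
y_v = torch.randn(C, 1, W)  # vertical strip result, expanded with a middle axis

y = y_h + y_v  # broadcasts to C x H x W: y[c, i, j] = y_h[c, i] + y_v[c, j]
assert torch.isclose(y[0, 1, 2], y_h[0, 1, 0] + y_v[0, 0, 2])
```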
$y$ is then passed through a channel-expanding $1\times 1$ convolution and a sigmoid, and the result is multiplied with the original input:
$$z=\mathrm{Scale}(x,\sigma(f(y)))$$
- Scale denotes element-wise multiplication;
- $\sigma$ is the sigmoid function;
- $f$ is a $1\times 1$ convolution.
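The gating step can be sketched in isolation as follows (my own minimal example; `f` is a plain $1\times 1$ convolution and `y` stands in for the fused strip-pooling map):

```python
import torch
import torch.nn as nn

C, H, W = 4, 6, 6
x = torch.randn(1, C, H, W)  # original input
y = torch.randn(1, C, H, W)  # fused strip-pooling map from the previous step

f = nn.Conv2d(C, C, kernel_size=1)  # 1x1 convolution
z = x * torch.sigmoid(f(y))         # Scale: per-position reweighting of x

assert z.shape == x.shape
```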
The code below is my reimplementation of StripPooling based on the structure described in the paper; the official repository does not contain a standalone StripPooling implementation.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv1x1(in_channel, out_channel):
    return nn.Conv2d(in_channel, out_channel, kernel_size=1, bias=False)


class StripPooling(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(StripPooling, self).__init__()
        inter_channel = in_channel // 4  # integer division: channel counts must be ints
        self.conv11_1 = nn.Sequential(conv1x1(in_channel, inter_channel), nn.BatchNorm2d(inter_channel), nn.ReLU(True))
        self.conv11_2 = nn.Sequential(conv1x1(in_channel, inter_channel), nn.BatchNorm2d(inter_channel), nn.ReLU(True))
        self.v_pool = nn.AdaptiveAvgPool2d((None, 1))  # average over W -> C x H x 1
        self.h_pool = nn.AdaptiveAvgPool2d((1, None))  # average over H -> C x 1 x W
        self.conv11 = nn.Sequential(conv1x1(inter_channel, out_channel), nn.BatchNorm2d(out_channel), nn.ReLU(True))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, h, w = x.shape
        hx = self.conv11_1(x)
        vx = self.conv11_2(x)
        hx = self.h_pool(hx)
        vx = self.v_pool(vx)
        # expand both strips back to h x w before fusing
        hx = F.interpolate(hx, (h, w), mode='bilinear', align_corners=False)
        vx = F.interpolate(vx, (h, w), mode='bilinear', align_corners=False)
        fusion = hx + vx
        fusion = self.conv11(fusion)
        fusion = self.sigmoid(fusion)
        # note: out_channel must equal in_channel for this element-wise product
        x = x.mul(fusion)
        return F.relu_(x)
```
3.3 Mixed Pooling Module
PPM and ASPP have largely demonstrated that multi-branch structures of this kind can effectively extract scene context. Using strip pooling, the authors design a PPM-like structure, the Mixed Pooling Module (MPM).
The MPM consists of two sub-modules that extract short-range and long-range features respectively: the long-range branch uses strip pooling, while the short-range branch is a small PPM.
In the official code this module is named StripPooling, but judging from the code its structure is exactly the MPM.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MMP(nn.Module):
    """
    Reference:
    """
    def __init__(self, in_channels, pool_size, norm_layer, up_kwargs):
        super(MMP, self).__init__()
        # short-range branch: two small square adaptive pools (a mini PPM)
        self.pool1 = nn.AdaptiveAvgPool2d(pool_size[0])
        self.pool2 = nn.AdaptiveAvgPool2d(pool_size[1])
        # long-range branch: horizontal and vertical strip pooling
        self.pool3 = nn.AdaptiveAvgPool2d((1, None))
        self.pool4 = nn.AdaptiveAvgPool2d((None, 1))
        inter_channels = int(in_channels/4)
        self.conv1_1 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 1, bias=False), norm_layer(inter_channels), nn.ReLU(True))
        self.conv1_2 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 1, bias=False), norm_layer(inter_channels), nn.ReLU(True))
        self.conv2_0 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False), norm_layer(inter_channels))
        self.conv2_1 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False), norm_layer(inter_channels))
        self.conv2_2 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False), norm_layer(inter_channels))
        self.conv2_3 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (1, 3), 1, (0, 1), bias=False), norm_layer(inter_channels))
        self.conv2_4 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, (3, 1), 1, (1, 0), bias=False), norm_layer(inter_channels))
        self.conv2_5 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False), norm_layer(inter_channels), nn.ReLU(True))
        self.conv2_6 = nn.Sequential(nn.Conv2d(inter_channels, inter_channels, 3, 1, 1, bias=False), norm_layer(inter_channels), nn.ReLU(True))
        self.conv3 = nn.Sequential(nn.Conv2d(inter_channels*2, in_channels, 1, bias=False), norm_layer(in_channels))
        # bilinear interpolate options
        self._up_kwargs = up_kwargs

    def forward(self, x):
        _, _, h, w = x.size()
        x1 = self.conv1_1(x)
        x2 = self.conv1_2(x)
        # short-range branch: identity conv plus two upsampled square pools
        x2_1 = self.conv2_0(x1)
        x2_2 = F.interpolate(self.conv2_1(self.pool1(x1)), (h, w), **self._up_kwargs)
        x2_3 = F.interpolate(self.conv2_2(self.pool2(x1)), (h, w), **self._up_kwargs)
        # long-range branch: upsampled horizontal and vertical strips
        x2_4 = F.interpolate(self.conv2_3(self.pool3(x2)), (h, w), **self._up_kwargs)
        x2_5 = F.interpolate(self.conv2_4(self.pool4(x2)), (h, w), **self._up_kwargs)
        x1 = self.conv2_5(F.relu_(x2_1 + x2_2 + x2_3))
        x2 = self.conv2_6(F.relu_(x2_5 + x2_4))
        out = self.conv3(torch.cat([x1, x2], dim=1))
        return F.relu_(x + out)
```
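For reference, the four pooling operations the MPM combines can be compared in isolation (my own sketch; the pool sizes 20 and 12 are assumed example values, not prescribed by the text above):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 8, 24, 24)

# short-range branch: small square adaptive pools (a mini PPM)
short1 = nn.AdaptiveAvgPool2d(20)(x)         # 20 x 20
short2 = nn.AdaptiveAvgPool2d(12)(x)         # 12 x 12
# long-range branch: strips spanning one full spatial dimension
long_h = nn.AdaptiveAvgPool2d((1, None))(x)  # 1 x 24
long_v = nn.AdaptiveAvgPool2d((None, 1))(x)  # 24 x 1

assert short1.shape == (1, 8, 20, 20)
assert long_h.shape == (1, 8, 1, 24)
```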
4 Experimental Results
Comparison against PPM: