Paper: Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation
Code: https://github.com/tensorflow/models/tree/master/research/deeplab
1、Atrous Spatial Pyramid Pooling
DeepLabV3+ is one of the classic papers in semantic segmentation, and its Atrous Spatial Pyramid Pooling (ASPP) module is especially well known, with many modified variants appearing in later work. Although the paper dates back to 2018, it is still worth summarizing here; more recent work will be covered in later posts.
ASPP is motivated by the following observations:
- The importance of multi-scale information: objects of different scales need receptive fields of different sizes to capture their features.
- The limitations of pooling: conventional pooling (e.g., max pooling) discards spatial information, making it hard to capture multi-scale features effectively.
By applying atrous (dilated) convolutions with different dilation rates in parallel, ASPP captures contextual information at multiple scales without reducing the resolution of the feature map.
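As a quick sanity check of the "larger receptive field without extra parameters" claim, a k x k kernel with dilation rate r covers an effective extent of k + (k - 1)(r - 1) pixels while still holding only k * k weights. A minimal sketch in plain Python (the rates are chosen to match the ones used later in this post):

# Effective kernel extent of a dilated convolution: k + (k - 1) * (r - 1)
def effective_kernel_size(k, r):
    return k + (k - 1) * (r - 1)

for r in (1, 6, 12, 18):
    # A 3x3 kernel always has 9 weights, but its footprint grows with the rate
    print(f"rate={r:>2}  effective extent={effective_kernel_size(3, r)}")
# rate= 1  effective extent=3
# rate= 6  effective extent=13
# rate=12  effective extent=25
# rate=18  effective extent=37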
ASPP is implemented in the following steps:
- Atrous convolution: atrous (dilated) convolution inserts holes between the kernel elements, enlarging the receptive field without increasing the number of parameters or the computational cost.
- Multiple dilation rates: the ASPP module applies atrous convolutions with different dilation rates in parallel; in this paper it uses four convolutional branches: a 1x1 convolution and three 3x3 convolutions with rates 6, 12, and 18 (see the shape check after this list).
- Image-level pooling: the ASPP module also includes a global average pooling branch that squeezes the spatial dimensions of the feature map down to 1x1 to capture global context.
- Feature fusion: finally, the outputs of the atrous convolution branches and the (upsampled) global pooling branch are concatenated and fused to produce the final feature map.
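Before the full module in the next section, here is a minimal PyTorch sketch (using a hypothetical 256-channel feature map) of why each 3x3 branch below sets padding equal to its dilation rate: with padding = rate and stride 1, the output keeps the input's spatial size, so all five branch outputs can be concatenated along the channel dimension.

import torch
import torch.nn as nn

x = torch.randn(1, 256, 32, 32)  # hypothetical backbone feature map
for rate in (6, 12, 18):
    # For a 3x3 kernel with stride 1, padding == dilation preserves H and W
    conv = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=rate, dilation=rate)
    print(rate, conv(x).shape)  # torch.Size([1, 256, 32, 32]) for every rate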
ASPP architecture diagram:
2、Code Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F


class ASPP(nn.Module):
    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()
        # Branch 1: 1x1 convolution (dilation has no effect on a 1x1 kernel)
        self.branch1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        # Branch 2: 3x3 atrous convolution, rate 6 (padding = dilation keeps spatial size)
        self.branch2 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=6 * rate, dilation=6 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        # Branch 3: 3x3 atrous convolution, rate 12
        self.branch3 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        # Branch 4: 3x3 atrous convolution, rate 18
        self.branch4 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=18 * rate, dilation=18 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        # Branch 5: image-level feature (global average pooling followed by 1x1 conv)
        self.branch5_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
        self.branch5_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch5_relu = nn.ReLU(inplace=True)
        # Fuse the five concatenated branches back to dim_out channels with a 1x1 conv
        self.conv_cat = nn.Sequential(
            nn.Conv2d(dim_out * 5, dim_out, 1, 1, padding=0, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        [b, c, row, col] = x.size()
        # Four convolutional branches; all preserve the input's spatial resolution
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)
        # Image-level branch: global average pooling to 1x1, then 1x1 conv + BN + ReLU
        global_feature = torch.mean(x, 2, True)
        global_feature = torch.mean(global_feature, 3, True)
        global_feature = self.branch5_conv(global_feature)
        global_feature = self.branch5_bn(global_feature)
        global_feature = self.branch5_relu(global_feature)
        # Upsample the global feature back to the input resolution
        global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)
        # Concatenate the five branches along the channel dimension and fuse with a 1x1 conv
        feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, global_feature], dim=1)
        result = self.conv_cat(feature_cat)
        return result


if __name__ == '__main__':
    # Quick shape check; fall back to CPU if CUDA is unavailable
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    x = torch.randn(4, 512, 7, 7).to(device)
    model = ASPP(512, 512).to(device)
    out = model(x)
    print(out.shape)  # torch.Size([4, 512, 7, 7])
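For context, in DeepLabV3+ the ASPP module sits on top of the backbone's last feature map. A hedged usage sketch, assuming the ASPP class above is in scope, a ResNet-style backbone with 2048 output channels at output stride 16, and the paper's 256 ASPP filters:

# Hypothetical encoder feature for a 512x512 input at output stride 16 (512 / 16 = 32)
backbone_feat = torch.randn(2, 2048, 32, 32)
aspp = ASPP(dim_in=2048, dim_out=256)
print(aspp(backbone_feat).shape)  # torch.Size([2, 256, 32, 32])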