ASPP是基于空洞卷积(Dilatd/Atrous Convolution)和SPP(空间金字塔池化)的。用空洞卷积代替了单纯的adaptivepooling. ASPP对所给定的输入以不同采样率的空洞卷积并行采样,相当于以多个比例捕捉图像的上下文。
ASPP实际上是空间金字塔池的一个版本,其中的概念已经在SPPNet中描述。在ASPP中,在输入特征映射中应用不同速率的并行空洞卷积,并融合在一起。由于同一类的物体在图像中可能有不同的比例,ASPP有助于考虑不同的物体比例,这可以提高准确性。
DeepLab v2中就有用到ASPP模块
这里设计了几种不同采样率的空洞卷积来捕捉多尺度信息,但我们要明白采样率(dilation rate)并不是越大越好,因为采样率太大,会导致滤波器有的会跑到padding上,产生无意义的权重,因此要选择合适的采样率。
Pytorch实现
import torch from torch import nn import torch.nn.functional as F class ASPP(nn.Module): def __init__(self, num_classes): super(ASPP, self).__init__() self.conv_1x1_1 = nn.Conv2d(2048, 256, kernel_size=1) self.bn_conv_1x1_1 = nn.BatchNorm2d(256) self.conv_3x3_6 = nn.Conv2d(2048, 256, kernel_size=3, stride=1, padding=6, dilation=6) self.bn_conv_3x3_6 = nn.BatchNorm2d(256) self.conv_3x3_12 = nn.Conv2d(2048, 256, kernel_size=3, stride=1, padding=12, dilation=12) self.bn_conv_3x3_12 = nn.BatchNorm2d(256) self.conv_3x3_18 = nn.Conv2d(2048, 256, kernel_size=3, stride=1, padding=18, dilation=18) self.bn_conv_3x3_18 = nn.BatchNorm2d(256) self.avg_pool = nn.AdaptiveAvgPool2d(1) self.conv_1x1_2 = nn.Conv2d(2048, 256, kernel_size=1) self.bn_conv_1x1_2 = nn.BatchNorm2d(256) self.conv_1x1_3 = nn.Conv2d(1280, 256, kernel_size=1) # (1280 = 5*256) self.bn_conv_1x1_3 = nn.BatchNorm2d(256) self.conv_1x1_4 = nn.Conv2d(256, num_classes, kernel_size=1) def forward(self, feature_map): # (feature_map has shape (batch_size, 2048, h/8, w/8)) feature_map_h = feature_map.size()[2] # (h/8) feature_map_w = feature_map.size()[3] # (w/8) out_1x1 = F.relu(self.bn_conv_1x1_1(self.conv_1x1_1(feature_map))) # (shape: (batch_size, 256, h/8, w/8)) 对应图中 E out_3x3_1 = F.relu(self.bn_conv_3x3_6(self.conv_3x3_6(feature_map))) # (shape: (batch_size, 256, h/8, w/8)) 对应图中 D out_3x3_2 = F.relu(self.bn_conv_3x3_12(self.conv_3x3_12(feature_map))) # (shape: (batch_size, 256, h/8, w/8)) 对应图中 C out_3x3_3 = F.relu(self.bn_conv_3x3_18(self.conv_3x3_18(feature_map))) # (shape: (batch_size, 256, h/8, w/8)) 对应图中 B #out_1x1,out_3x3_1,out_3x3_2,out_3x3_3 的shape都一样 out_img = self.avg_pool(feature_map) # (shape: (batch_size, 512, 1, 1))对应图中 ImagePooling out_img = F.relu(self.bn_conv_1x1_2(self.conv_1x1_2(out_img))) # (shape: (batch_size, 256, 1, 1)) out_img = F.upsample(out_img, size=(feature_map_h, feature_map_w), mode="bilinear") # (shape: (batch_size, 256, h/8, w/8))对应图中 A out = torch.cat([out_1x1, out_3x3_1, out_3x3_2, out_3x3_3, out_img], dim=1) # (shape: (batch_size, 1280, h/8, w/8)) cat对应图中 F out = F.relu(self.bn_conv_1x1_3(self.conv_1x1_3(out))) # (shape: (batch_size, 256, h/8, w/8)) bn_conv_1x1_3对应图中 H out 对应图中I out = self.conv_1x1_4(out) # (shape: (batch_size, num_classes, h/8, w/8))out 对应图中Upsample by 4 return out if __name__ == '__main__': x = torch.rand(4,2048,28,28) #[b,c,h,w] aspp = ASPP(num_classes=10) out = aspp(x) #[b,num_class,h,w]