A Single Stream Network for Robust and Real-time RGB-D Salient Object Detection
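Existing problem:
Most RGB-D salient object detection methods focus on cross-modal fusion between the RGB stream and the depth stream, and do not explore in depth what the depth map itself can contribute.
Proposed approach:
A single-stream network is designed that uses the depth map directly to guide both the early fusion and the middle fusion between RGB and depth. This removes the need for a feature encoder for the depth stream and yields a lightweight, real-time model.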
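1.
(1) The large gap between the two modalities causes incompatibility problems that have to be overcome;
(2) Most methods use a two-stream structure to extract features from RGB and depth separately, which greatly increases the number of network parameters;
(3) Existing RGB-D datasets are small and RGB and depth differ considerably, so simply concatenating the RGB and depth channels and feeding them to the network makes a deep network hard to train.
The authors therefore build a single-stream encoder that performs early fusion, so that a backbone pretrained on ImageNet can be fully exploited to extract rich and discriminative features.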
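Point (2) can be made concrete with a rough parameter count. The sketch below counts only the 13 convolution layers of a plain VGG-16 (no BatchNorm, no classifier) and assumes that a two-stream baseline runs one full VGG-16 per modality on a 3-channel input. That baseline layout and the helper name `conv_params` are illustrative assumptions, not figures from the paper.

```python
# Output widths of the 13 convolution layers in VGG-16 (all 3x3 kernels).
VGG16_CONV_CHANNELS = [64, 64, 128, 128, 256, 256, 256, 512, 512, 512, 512, 512, 512]


def conv_params(in_channels, kernel_size=3):
    """Count weights and biases of the VGG-16 convolution stack."""
    total = 0
    for out_channels in VGG16_CONV_CHANNELS:
        total += in_channels * out_channels * kernel_size ** 2 + out_channels
        in_channels = out_channels
    return total


# Assumption: two-stream baseline = one VGG-16 per modality, 3-channel input each
# (depth replicated to three channels).
two_stream = 2 * conv_params(3)
# Single stream: one VGG-16 whose first convolution takes RGB + depth (4 channels).
single_stream = conv_params(4)

print(f"two-stream encoder convs:    {two_stream:,}")
print(f"single-stream encoder convs: {single_stream:,}")
```

Under these assumptions the fourth input channel adds only 64 x 9 = 576 weights to the first convolution, while a second encoder doubles the count.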
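2.
A depth map depicts contrast cues at different positions, which gives important guidance for separating foreground from background. The authors therefore introduce a spatial filtering mechanism between the encoder and the decoder that explicitly uses the depth map to guide the computation of a dual attention, which helps the foreground and background decoding branches discriminate features.
A new depth-enhanced dual attention module (DEDA) supplies the foreground and background branches with spatially filtered features, so that the decoder can carry out the middle fusion as well as possible.
The module uses a mask-guided strategy and a depth-guided strategy to filter out the mutual interference between depth and appearance, which strengthens the overall contrast between foreground and background.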
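3.
Objects vary in size, so locating them accurately requires multi-scale context information.
A pyramidally attended feature extraction module (PAFE) is therefore proposed to locate objects of different scales accurately. The module can model the spatial dependency between any two positions in a feature map.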
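The multi-scale part of PAFE is implemented further below with 3x3 convolutions at dilation rates 2, 4 and 6, each with padding equal to its dilation. As a small check of why that pairing is used, the sketch below applies the standard convolution output-size formula; the helper names and the 16-pixel feature map are illustrative assumptions.

```python
def effective_kernel_size(kernel_size, dilation):
    """Span covered by a dilated kernel along one axis."""
    return dilation * (kernel_size - 1) + 1


def conv_output_size(size, kernel_size, dilation, padding, stride=1):
    """Output length of a convolution along one spatial axis."""
    span = effective_kernel_size(kernel_size, dilation)
    return (size + 2 * padding - span) // stride + 1


feature_size = 16  # hypothetical height/width of the encoder output
for dilation in (2, 4, 6):
    span = effective_kernel_size(3, dilation)
    out = conv_output_size(feature_size, 3, dilation, padding=dilation)
    print(f"dilation={dilation}: covers {span}x{span}, output size {out}")
```

Each branch sees a wider neighbourhood (5, 9 and 13 pixels) while keeping the same output resolution, which is what allows the branch outputs to be concatenated channel-wise.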
Overall network architecture:
the VGG-16 (E1 ∼ E5)
five transition layers (T1 ∼ T5), each using a 3x3 convolution
five saliency layers (S1 ∼ S5)
five background layers (B1 ∼ B5)
the depth-enhanced dual attention module (DEDA)
The final prediction is generated by using residual connections to fuse the outputs from S1 and B1.
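I. Single Stream Encoder Network
The single-stream encoder uses a fully convolutional network (FCN) structure. An FCN replaces the fully connected layers of a traditional CNN with convolutional layers.
VGG-16 is used as the backbone; all fully connected layers are discarded and the last pooling layer is removed.
The code is as follows: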
################################vgg16#######################################
feats = list(models.vgg16_bn(pretrained=True).features.children())
self.conv0 = nn.Conv2d(4, 64, kernel_size=3, padding=1)
self.conv1 = nn.Sequential(*feats[1:6])
self.conv2 = nn.Sequential(*feats[6:13])
self.conv3 = nn.Sequential(*feats[13:23])
self.conv4 = nn.Sequential(*feats[23:33])
self.conv5 = nn.Sequential(*feats[33:43])
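vgg.features extracts the features part of the VGG model. children() returns every layer in that structure, that is, every layer inside the Sequential.
Printing the downloaded vgg16_bn shows the detailed layer structure: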
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(9): ReLU(inplace=True)
(10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(12): ReLU(inplace=True)
(13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(14): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(16): ReLU(inplace=True)
(17): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(18): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(19): ReLU(inplace=True)
(20): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(21): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(22): ReLU(inplace=True)
(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(24): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(25): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(26): ReLU(inplace=True)
(27): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(28): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(29): ReLU(inplace=True)
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(32): ReLU(inplace=True)
(33): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(35): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(36): ReLU(inplace=True)
(37): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(38): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(39): ReLU(inplace=True)
(40): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(41): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(42): ReLU(inplace=True)
(43): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace=True)
(2): Dropout(p=0.5, inplace=False)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace=True)
(5): Dropout(p=0.5, inplace=False)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
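The slice boundaries in the encoder code can be checked against this listing without loading the model. The sketch below rebuilds the 44 layer names from the VGG-16 configuration and applies the same slices. The names `CFG`, `layers` and `blocks` are illustrative and not part of torchvision.

```python
# Layout of the vgg16_bn feature extractor: a number is a convolution's output
# width (followed by BatchNorm and ReLU), "M" is a max-pooling layer.
CFG = [64, 64, "M", 128, 128, "M", 256, 256, 256, "M",
       512, 512, 512, "M", 512, 512, 512, "M"]

layers = []
for item in CFG:
    if item == "M":
        layers.append("MaxPool2d")
    else:
        layers.extend([f"Conv2d({item})", "BatchNorm2d", "ReLU"])

# Same slice boundaries as in the encoder code.
blocks = {
    "conv1": layers[1:6],
    "conv2": layers[6:13],
    "conv3": layers[13:23],
    "conv4": layers[23:33],
    "conv5": layers[33:43],
}

for name, block in blocks.items():
    print(name, block)
```

Index 0 (the 3-channel convolution) is skipped because `self.conv0` replaces it with a 4-channel convolution, and index 43 (the last max-pooling) is left out, which matches the statement that the final pooling layer is removed.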
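II. The depth-enhanced dual attention module (DEDA)
Mask supervision and a depth-guidance mechanism are used to filter out erroneous information.
A_m: mask-guided attention; T_i: the current transition layer; S_{i+1}: the preceding decoder block; D: the depth map. In the code below A_m is computed as sigmoid(conv(T_i + S_{i+1} + D)), with S_{i+1} and D resized to the resolution of T_i (the first stage has no S_{i+1} term).
The resulting A_m has two problems:
(1) Some background regions are misclassified as salient, so depth information is introduced to refine A_m. The resulting depth-enhanced attention of the saliency branch gives extra contrast guidance for the misjudged regions in A_m and keeps the contrast between foreground and background high, which strengthens A_m.
(2) Some salient regions are mislabeled as background, so a depth-enhanced attention for the background branch is designed in the same way, using 1 - A_m.
The code is as follows: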
################################DAM for Saliency branch&Background branch#######################################
#mask-guided attention Am
dem1_attention = F.sigmoid(self.fuse_1(dem1+F.upsample(depth, size=dem1.size()[2:], mode='bilinear'))) #Am=sigmoid(conv(T5+D))
#the depth-enhanced attention of the saliency branch Asd
output1 = self.output1(dem1*(dem1_attention*(F.upsample(depth, size=dem1.size()[2:], mode='bilinear')+dem1_attention)))
#the depth-enhanced attention of the background branch Abd
output1_rev = self.output1_rev(dem1*((1-dem1_attention)*(F.upsample(depth, size=dem1.size()[2:], mode='bilinear')+(1-dem1_attention))))
dem2_attention = F.sigmoid(self.fuse_2(dem2+F.upsample(output1, size=dem2.size()[2:], mode='bilinear')+F.upsample(depth, size=dem2.size()[2:], mode='bilinear')))
output2 = self.output2(F.upsample(output1, size=dem2.size()[2:], mode='bilinear')+dem2*(dem2_attention*(F.upsample(depth, size=dem2.size()[2:], mode='bilinear')+dem2_attention)))
output2_rev = self.output2_rev(F.upsample(output1_rev, size=dem2.size()[2:], mode='bilinear')+dem2*((1-dem2_attention)*(F.upsample(depth, size=dem2.size()[2:], mode='bilinear')+(1-dem2_attention))))
dem3_attention = F.sigmoid(self.fuse_3(dem3+F.upsample(output2, size=dem3.size()[2:], mode='bilinear')+F.upsample(depth, size=dem3.size()[2:], mode='bilinear')))
output3 = self.output3(F.upsample(output2, size=dem3.size()[2:], mode='bilinear')+dem3*(dem3_attention*(F.upsample(depth, size=dem3.size()[2:], mode='bilinear')+dem3_attention)))
output3_rev = self.output3_rev(F.upsample(output2_rev, size=dem3.size()[2:], mode='bilinear')+dem3*((1-dem3_attention)*(F.upsample(depth, size=dem3.size()[2:], mode='bilinear')+(1-dem3_attention))))
dem4_attention = F.sigmoid(self.fuse_4(dem4+F.upsample(output3, size=dem4.size()[2:], mode='bilinear')+F.upsample(depth, size=dem4.size()[2:], mode='bilinear')))
output4 = self.output4(F.upsample(output3, size=dem4.size()[2:], mode='bilinear')+dem4*(dem4_attention*(F.upsample(depth, size=dem4.size()[2:], mode='bilinear')+dem4_attention)))
output4_rev = self.output4_rev(F.upsample(output3_rev, size=dem4.size()[2:], mode='bilinear')+dem4*((1-dem4_attention)*(F.upsample(depth, size=dem4.size()[2:], mode='bilinear')+(1-dem4_attention))))
dem5_attention = F.sigmoid(self.fuse_5(dem5+F.upsample(output4, size=dem5.size()[2:], mode='bilinear')+F.upsample(depth, size=dem5.size()[2:], mode='bilinear')))
output5 = self.output5(F.upsample(output4, size=dem5.size()[2:], mode='bilinear')+dem5*(dem5_attention*(F.upsample(depth, size=dem5.size()[2:], mode='bilinear')+dem5_attention)))
output5_rev = self.output5_rev(F.upsample(output4_rev, size=dem5.size()[2:], mode='bilinear')+dem5*((1-dem5_attention)*(F.upsample(depth, size=dem5.size()[2:], mode='bilinear')+(1-dem5_attention))))
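To see what the two weighting terms in this code do, they can be traced on single pixels. The sketch below reproduces only the element-wise factor that multiplies the transition feature, `A_m * (D + A_m)` for the saliency branch and `(1 - A_m) * (D + (1 - A_m))` for the background branch. The function names and the sample values (attention and depth both assumed to lie in [0, 1]) are illustrative.

```python
def saliency_weight(a_m, depth):
    """Factor applied to the transition feature in the saliency branch."""
    return a_m * (depth + a_m)


def background_weight(a_m, depth):
    """Same form with the attention inverted, used by the background branch."""
    return (1.0 - a_m) * (depth + (1.0 - a_m))


# (mask-guided attention, depth value) for four hypothetical pixels
pixels = [(0.9, 0.8), (0.9, 0.1), (0.1, 0.8), (0.1, 0.1)]
for a_m, depth in pixels:
    print(f"A_m={a_m:.1f} D={depth:.1f} "
          f"saliency={saliency_weight(a_m, depth):.2f} "
          f"background={background_weight(a_m, depth):.2f}")
```

For a fixed attention value the weight grows with the depth value, and the two branches are mirror images of each other: swapping A_m for 1 - A_m turns one weight into the other.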
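III. Pyramidally Attended Feature Extraction (PAFE)
The attention map A (called S in the code comments below) is computed as $A = \mathrm{softmax}(B\,C)$, where B is the query projection reshaped to (HW × C') and C is the key projection reshaped to (C' × HW).
The output is computed as $\alpha \cdot \mathrm{reshape}(D\,A^{\top}) + X$, where D is the value projection, X is the input feature of the branch and $\alpha$ is a learnable scalar (gamma in the code).
The implementation is as follows: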
###ECCV2020 A Single Stream Network for Robust and Real-time RGB-D Salient Object Detection
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv2d, Parameter, Softmax


class PAFEM(nn.Module):
    def __init__(self, dim, in_dim):
        super(PAFEM, self).__init__()
        self.down_conv = nn.Sequential(
            nn.Conv2d(dim, in_dim, 3, padding=1), nn.BatchNorm2d(in_dim), nn.PReLU()
        )
        down_dim = in_dim // 2
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_dim, down_dim, kernel_size=1), nn.BatchNorm2d(down_dim), nn.PReLU()
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_dim, down_dim, kernel_size=3, dilation=2, padding=2), nn.BatchNorm2d(down_dim), nn.PReLU()
        )
        # Projections that produce B, C and D; the output channels of B and C are reduced by a factor of 8
        self.query_conv2 = Conv2d(in_channels=down_dim, out_channels=down_dim // 8, kernel_size=1)
        self.key_conv2 = Conv2d(in_channels=down_dim, out_channels=down_dim // 8, kernel_size=1)
        self.value_conv2 = Conv2d(in_channels=down_dim, out_channels=down_dim, kernel_size=1)
        # gamma corresponds to alpha in the formula
        self.gamma2 = Parameter(torch.zeros(1))
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_dim, down_dim, kernel_size=3, dilation=4, padding=4), nn.BatchNorm2d(down_dim), nn.PReLU()
        )
        self.query_conv3 = Conv2d(in_channels=down_dim, out_channels=down_dim // 8, kernel_size=1)
        self.key_conv3 = Conv2d(in_channels=down_dim, out_channels=down_dim // 8, kernel_size=1)
        self.value_conv3 = Conv2d(in_channels=down_dim, out_channels=down_dim, kernel_size=1)
        self.gamma3 = Parameter(torch.zeros(1))
        self.conv4 = nn.Sequential(
            nn.Conv2d(in_dim, down_dim, kernel_size=3, dilation=6, padding=6), nn.BatchNorm2d(down_dim), nn.PReLU()
        )
        self.query_conv4 = Conv2d(in_channels=down_dim, out_channels=down_dim // 8, kernel_size=1)
        self.key_conv4 = Conv2d(in_channels=down_dim, out_channels=down_dim // 8, kernel_size=1)
        self.value_conv4 = Conv2d(in_channels=down_dim, out_channels=down_dim, kernel_size=1)
        self.gamma4 = Parameter(torch.zeros(1))
        # Applied to a 1x1 pooled map: BatchNorm fails here in training mode when the batch size is 1
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_dim, down_dim, kernel_size=1), nn.BatchNorm2d(down_dim), nn.PReLU()
        )
        self.fuse = nn.Sequential(
            nn.Conv2d(5 * down_dim, in_dim, kernel_size=1), nn.BatchNorm2d(in_dim), nn.PReLU()
        )
        self.softmax = Softmax(dim=-1)

    def forward(self, x):
        x = self.down_conv(x)
        conv1 = self.conv1(x)
        conv2 = self.conv2(x)
        m_batchsize, C, height, width = conv2.size()
        # B: (N,C',H,W) -> (N,C',HW) -> (N,HW,C')
        proj_query2 = self.query_conv2(conv2).view(m_batchsize, -1, width * height).permute(0, 2, 1)
        # C: (N,C',H,W) -> (N,C',HW)
        proj_key2 = self.key_conv2(conv2).view(m_batchsize, -1, width * height)
        # torch.bmm is a batched matrix product: B x C -> (N,HW,HW)
        energy2 = torch.bmm(proj_query2, proj_key2)
        # S = softmax(B x C) gives the attention map -> (N,HW,HW)
        attention2 = self.softmax(energy2)
        # D -> (N,C,HW)
        proj_value2 = self.value_conv2(conv2).view(m_batchsize, -1, width * height)
        # D x S (transposed) -> (N,C,HW)
        out2 = torch.bmm(proj_value2, attention2.permute(0, 2, 1))
        # reshape(D x S) -> (N,C,H,W)
        out2 = out2.view(m_batchsize, C, height, width)
        out2 = self.gamma2 * out2 + conv2
        conv3 = self.conv3(x)
        m_batchsize, C, height, width = conv3.size()
        proj_query3 = self.query_conv3(conv3).view(m_batchsize, -1, width * height).permute(0, 2, 1)
        proj_key3 = self.key_conv3(conv3).view(m_batchsize, -1, width * height)
        energy3 = torch.bmm(proj_query3, proj_key3)
        attention3 = self.softmax(energy3)
        proj_value3 = self.value_conv3(conv3).view(m_batchsize, -1, width * height)
        out3 = torch.bmm(proj_value3, attention3.permute(0, 2, 1))
        out3 = out3.view(m_batchsize, C, height, width)
        out3 = self.gamma3 * out3 + conv3
        conv4 = self.conv4(x)
        m_batchsize, C, height, width = conv4.size()
        proj_query4 = self.query_conv4(conv4).view(m_batchsize, -1, width * height).permute(0, 2, 1)
        proj_key4 = self.key_conv4(conv4).view(m_batchsize, -1, width * height)
        energy4 = torch.bmm(proj_query4, proj_key4)
        attention4 = self.softmax(energy4)
        proj_value4 = self.value_conv4(conv4).view(m_batchsize, -1, width * height)
        out4 = torch.bmm(proj_value4, attention4.permute(0, 2, 1))
        out4 = out4.view(m_batchsize, C, height, width)
        out4 = self.gamma4 * out4 + conv4
        # Global-context branch; this line fails in training mode when the batch size is 1 (see conv5 above)
        conv5 = F.interpolate(self.conv5(F.adaptive_avg_pool2d(x, 1)), size=x.size()[2:], mode='bilinear', align_corners=False)
        return self.fuse(torch.cat((conv1, out2, out3, out4, conv5), 1))
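This part applies a position attention module; the source code of that module is largely the same as the code above.
The original post pointed to two blog posts for the code of this module; the link targets are not preserved here.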
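The attention step used in each dilated branch can be followed on a tiny example without PyTorch. The sketch below is a simplified stand-in: it skips the 1x1 query/key/value convolutions and uses the input features directly for all three, which is an assumption made only to keep the example short. `position_attention` and `softmax` are illustrative names.

```python
import math


def softmax(row):
    """Numerically stable softmax over one list of scores."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def position_attention(features, gamma):
    """features: C channels, each a flat list of HW positions."""
    channels, positions = len(features), len(features[0])
    # energy[i][j]: dot product of the feature vectors at positions i and j
    energy = [[sum(features[c][i] * features[c][j] for c in range(channels))
               for j in range(positions)] for i in range(positions)]
    # Softmax over the last axis, as in Softmax(dim=-1)
    attention = [softmax(row) for row in energy]
    # out[c][i] = sum_j value[c][j] * attention[i][j]  (value times attention transposed)
    out = [[sum(features[c][j] * attention[i][j] for j in range(positions))
            for i in range(positions)] for c in range(channels)]
    # Residual connection scaled by gamma
    fused = [[gamma * out[c][i] + features[c][i] for i in range(positions)]
             for c in range(channels)]
    return fused, attention


features = [[1.0, 0.0, 1.0],   # channel 0 at three positions
            [0.0, 1.0, 0.0]]   # channel 1 at three positions
fused, attention = position_attention(features, gamma=0.0)
print(attention[0])
print(fused)
```

Positions 0 and 2 carry the same feature vector, so position 0 attends to them equally and more strongly than to position 1. With gamma at its initial value of zero, as set by `torch.zeros(1)` in the module, the output equals the input, so the attention path is phased in gradually during training.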
Implementation details:
epochs: 40
mini-batch size: 4
optimizer: SGD, momentum = 0.9, weight decay = 0.0005
learning rate: 0.001, "poly" policy with power = 0.9
loss: binary cross-entropy loss
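The "poly" policy is commonly defined as lr = base_lr * (1 - iter / max_iter) ^ power. The sketch below uses that common form with the values listed above; whether the schedule is stepped per iteration or per epoch, and the total of 1000 iterations, are assumptions for illustration only.

```python
def poly_lr(base_lr, current_iter, max_iter, power=0.9):
    """Learning rate under the common 'poly' decay policy."""
    return base_lr * (1.0 - current_iter / max_iter) ** power


max_iter = 1000  # hypothetical total number of iterations
for step in (0, 250, 500, 750, 1000):
    print(step, poly_lr(0.001, step, max_iter))
```

With power = 0.9 the curve is close to a linear decay from 0.001 to zero, falling slightly more slowly at first and more steeply near the end of training.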