A salient object detection paper published at CVPR 2019, which I came across while reading references.
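Some blog posts explain it in considerable detail; I have listed them below.
I also drew the network diagram myself to deepen my understanding.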
1. Overall Framework
2. RFB Module
```python
import torch
import torch.nn as nn

# BasicConv2d is a convolution helper defined elsewhere in the original code (not shown here).

class RFB(nn.Module):
    # RFB-like multi-scale module
    def __init__(self, in_channel, out_channel):
        super(RFB, self).__init__()
        self.relu = nn.ReLU(True)
        # branch0: a single 1x1 convolution
        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        # branch1-3: 1x1 -> 1xk -> kx1 -> 3x3 dilated (k = 3, 5, 7; dilation = k);
        # every layer is padded so the spatial size is preserved
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
            BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7)
        )
        # fuse the concatenated branches, plus a 1x1 residual shortcut
        self.conv_cat = BasicConv2d(4 * out_channel, out_channel, 3, padding=1)
        self.conv_res = BasicConv2d(in_channel, out_channel, 1)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)
        x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1))  # concatenate along channels
        x = self.relu(x_cat + self.conv_res(x))
        return x
```
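The four branches are concatenated along channels, so they must produce the same H×W. The stride-1 formulas let you check this and see what receptive field each branch adds: out = n + 2p - d(k-1) and rf += d(k-1). The sketch below is plain Python. The layer table is transcribed from the RFB code above. The helper `trace` and the 44×44 input size are hypothetical. The extra 3×3 in `conv_cat`, which would widen every branch by 2, is left out.

```python
# Each layer: (kernel_h, kernel_w, pad_h, pad_w, dilation); all convolutions use stride 1.
BRANCHES = {
    0: [(1, 1, 0, 0, 1)],
    1: [(1, 1, 0, 0, 1), (1, 3, 0, 1, 1), (3, 1, 1, 0, 1), (3, 3, 3, 3, 3)],
    2: [(1, 1, 0, 0, 1), (1, 5, 0, 2, 1), (5, 1, 2, 0, 1), (3, 3, 5, 5, 5)],
    3: [(1, 1, 0, 0, 1), (1, 7, 0, 3, 1), (7, 1, 3, 0, 1), (3, 3, 7, 7, 7)],
}


def trace(layers, h, w):
    """Return (out_h, out_w, rf_h, rf_w) after a stack of stride-1 convolutions."""
    rf_h = rf_w = 1
    for kh, kw, ph, pw, d in layers:
        h = h + 2 * ph - d * (kh - 1)   # output height
        w = w + 2 * pw - d * (kw - 1)   # output width
        rf_h += d * (kh - 1)            # receptive field grows by d*(k-1) at stride 1
        rf_w += d * (kw - 1)
    return h, w, rf_h, rf_w


for idx, layers in BRANCHES.items():
    print(idx, trace(layers, 44, 44))
# 0 (44, 44, 1, 1)
# 1 (44, 44, 9, 9)
# 2 (44, 44, 15, 15)
# 3 (44, 44, 21, 21)
```

Every branch keeps 44×44, so `torch.cat(..., 1)` is valid. The dilated 3×3 at the end of branches 1-3 stretches their receptive fields to 9, 15 and 21 pixels on the RFB input, while each branch adds only a few layers.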
3. Aggregation Module
```python
import torch
import torch.nn as nn

# BasicConv2d is a convolution helper defined elsewhere in the original code (not shown here).

class aggregation(nn.Module):
    # dense aggregation; it can be replaced by other aggregation models, such as DSS, Amulet, and so on.
    # used after MSF
    # x1 is the deepest (coarsest) input; x2 and x3 must be 2x and 4x larger spatially.
    def __init__(self, channel):
        super(aggregation, self).__init__()
        self.relu = nn.ReLU(True)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample5 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat2 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat3 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv4 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv5 = nn.Conv2d(3 * channel, 1, 1)

    def forward(self, x1, x2, x3):
        # multiply each level by upsampled versions of all deeper levels
        x1_1 = x1
        x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
        x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \
            * self.conv_upsample3(self.upsample(x2)) * x3
        # progressively concatenate upsampled coarser results: C -> 2C -> 3C channels
        x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
        x2_2 = self.conv_concat2(x2_2)
        x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
        x3_2 = self.conv_concat3(x3_2)
        # single-channel prediction at x3's resolution (no activation applied here)
        x = self.conv4(x3_2)
        x = self.conv5(x)
        return x
```
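Set the 3×3 convolutions aside, and the forward pass above fuses levels by element-wise multiplication. A deeper level is upsampled and multiplied into the shallower ones, so in one reading a deep-level response acts as a soft gate on the finer features. The toy sketch below isolates that gating. The hand-made 2×2, 4×4 and 8×8 maps are hypothetical, and the `BasicConv2d` layers are dropped. It uses the same `nn.Upsample` settings as the module. It also shows why x2 and x3 must be exactly 2× and 4× the size of x1: otherwise the products would not broadcast.

```python
import torch
import torch.nn as nn

# Same layer as self.upsample in aggregation: 2x bilinear with align_corners=True.
upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

# Toy single-channel maps (N=1, C=1); x1 plays the deepest, coarsest level.
x1 = torch.tensor([[[[1.0, 0.0],
                     [0.0, 0.0]]]])   # 2x2: only the top-left cell responds
x2 = torch.ones(1, 1, 4, 4)           # 4x4 shallower level with a uniform response
x3 = torch.ones(1, 1, 8, 8)           # 8x8 shallowest level with a uniform response

# Mirrors x2_1 and x3_1 without the 3x3 convolutions, leaving only the gating.
x2_1 = upsample(x1) * x2
x3_1 = upsample(upsample(x1)) * upsample(x2) * x3

print(x2_1[0, 0])
```

The uniform shallow responses survive only where the upsampled deep map is large. They keep a value of 1 at the top-left corner and are zeroed along the bottom and right edges, with a smooth bilinear falloff in between.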
4. Holistic Attention Module
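- This module enlarges the coverage of the initial saliency prediction map, which helps with problems such as inaccurate boundaries and incomplete results.
- When the attention branch produces an accurate saliency map, using that map directly as attention on the features effectively suppresses distractors in the features.
- Conversely, if distractors are classified as salient regions, this strategy produces abnormal segmentation results, so the initial saliency map needs to be made more effective. More specifically, the edges of salient objects are hard to predict precisely and may be filtered out by the initial saliency map. In addition, some objects in complex scenes are hard to segment completely. The holistic attention module is therefore proposed to enlarge the coverage of the initial saliency map.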
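Writing S_i for the initial saliency map, the holistic attention map is

$$S_h = \mathrm{MAX}\big(f_{\min\max}(\mathrm{Conv}_g(S_i, k)),\ S_i\big)$$

Here Conv_g(·) is a convolution with Gaussian kernel k and zero bias, and f_min_max(·) is a normalization function that rescales the blurred map to [0, 1]. MAX(·) takes the element-wise maximum, which tends to increase the weights of the salient regions of the smoothed map. Inside the salient region the original map is larger than the blurred map, so the original values are kept. In the non-salient regions the blurred map is larger than the original, so attention to the boundary of the original saliency map increases and the perceived salient area is enlarged.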
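To see the effect of MAX concretely, here is a toy torch sketch of the formula. It uses a synthetic 32×32 initial map S_i containing a 10×10 salient square; all values are hypothetical. It computes the min-max-normalized Gaussian blur of S_i and then takes the element-wise maximum with S_i. Assumption: the kernel is a standard discrete Gaussian with σ = 4 pixels built directly in torch, not the `gkern` helper used in the original code, so the numbers differ slightly from the HA module. The helper names `gaussian_kernel2d` and `min_max` are hypothetical.

```python
import torch
import torch.nn.functional as F


def gaussian_kernel2d(size: int = 31, sigma: float = 4.0) -> torch.Tensor:
    """Normalized 2-D Gaussian of shape (1, 1, size, size); a stand-in for gkern."""
    coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
    g = torch.exp(-coords ** 2 / (2 * sigma ** 2))
    k = torch.outer(g, g)
    return (k / k.sum()).view(1, 1, size, size)


def min_max(t: torch.Tensor) -> torch.Tensor:
    """Rescale each (N, C) map to [0, 1] using its own spatial min and max."""
    lo = t.amin(dim=(2, 3), keepdim=True)
    hi = t.amax(dim=(2, 3), keepdim=True)
    return (t - lo) / (hi - lo + 1e-8)


# Toy initial saliency map S_i: a 10x10 square of ones on a 32x32 canvas.
s_i = torch.zeros(1, 1, 32, 32)
s_i[..., 11:21, 11:21] = 1.0

k = gaussian_kernel2d(31, 4.0)
blurred = min_max(F.conv2d(s_i, k, padding=15))  # normalized blur, same 32x32 size
attention = torch.maximum(blurred, s_i)          # element-wise MAX with the original

print(int((s_i > 0.5).sum()), int((attention > 0.5).sum()))
```

Every pixel of the square stays at exactly 1, because the normalized blur never exceeds 1. Pixels just outside its edges, such as row 10, column 15, rise above 0.5. As a result the region above 0.5 becomes larger than the original 100-pixel square, which is the enlarged coverage described above.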
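The attention map is multiplied element-wise with the features of the third convolutional block to obtain refined, attention-weighted features. These are fed into the decoder together with the fourth- and fifth-block features to produce a new saliency prediction map. Compared with the initial attention, the proposed holistic attention adds some computational cost, but it further highlights the salient object as a whole.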
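Note: the Gaussian kernel k is initialized with a size of 32 and a standard deviation of 4, and its weights are then learned automatically during training, since the kernel is registered as a Parameter. In the code below the kernel is actually built with gkern(31, 4). The odd size of 31 with padding=15 keeps the spatial size unchanged. The 4 is passed as gkern's nsig argument, which sets how many standard deviations the 31-tap window spans rather than a standard deviation in pixels.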
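```python
import numpy as np
import scipy.stats as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter

class HA(nn.Module):
    # holistic attention module
    def __init__(self):
        super(HA, self).__init__()
        # 31x31 Gaussian initialization, stored as a learnable (1, 1, 31, 31) Parameter
        gaussian_kernel = np.float32(gkern(31, 4))
        gaussian_kernel = gaussian_kernel[np.newaxis, np.newaxis, ...]
        self.gaussian_kernel = Parameter(torch.from_numpy(gaussian_kernel))

    def forward(self, attention, x):
        # attention: (N, 1, H, W) initial saliency map, assumed in [0, 1]; x: (N, C, H, W) features
        soft_attention = F.conv2d(attention, self.gaussian_kernel, padding=15)  # blur, size preserved
        soft_attention = min_max_norm(soft_attention)                            # rescale to [0, 1]
        # element-wise max with the original map, broadcast over the C feature channels
        x = torch.mul(x, soft_attention.max(attention))
        return x

def gkern(kernlen=16, nsig=3):
    """Return a kernlen x kernlen Gaussian kernel spanning about +/- nsig standard deviations, summing to 1."""
    interval = (2*nsig+1.)/kernlen
    x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1)
    kern1d = np.diff(st.norm.cdf(x))
    kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
    kernel = kernel_raw/kernel_raw.sum()
    return kernel

def min_max_norm(in_):
    """Rescale each (N, C) map to [0, 1] using its own spatial min and max."""
    max_ = in_.max(3)[0].max(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
    min_ = in_.min(3)[0].min(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
    in_ = in_ - min_
    return in_.div(max_-min_+1e-8)
```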
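The note on the kernel gives a size of 32, while HA builds the kernel with gkern(31, 4) and convolves with padding=15. One plausible reason (my reading, not stated in the source) is that the blurred map must keep the same H×W as x for `torch.mul` to work. A stride-1 convolution outputs H + 2p - k + 1, which equals H only for an odd k with p = (k - 1)/2. The small check below is a sketch. The 88×88 map is an arbitrary example size, and `out_hw` is a hypothetical helper.

```python
import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 88, 88)  # any single-channel map; 88x88 is an arbitrary example size


def out_hw(kernel_size: int, padding: int) -> tuple:
    """Spatial size after a stride-1 conv2d with a square averaging kernel."""
    weight = torch.ones(1, 1, kernel_size, kernel_size) / kernel_size ** 2
    return tuple(F.conv2d(x, weight, padding=padding).shape[-2:])


print(out_hw(31, 15))  # (88, 88): odd kernel with symmetric padding keeps the size
print(out_hw(32, 15))  # (87, 87): an even kernel loses one pixel...
print(out_hw(32, 16))  # (89, 89): ...or gains one, so the map no longer matches x
```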