Cascaded Partial Decoder for Fast and Accurate Salient Object Detection

Published at CVPR 2019; a salient object detection paper I came across while going through references.

Some blog posts explain it in good detail; one is linked below.

Cascaded Partial Decoder for Fast and Accurate Salient Object Detection: https://blog.csdn.net/bananalone/article/details/106881181


I also drew the network diagram myself to deepen my understanding.

Contents

1. Overall framework

2. RFB module

3. aggregation module

4. Holistic Attention Module


1. Overall framework

[Figure: overall network diagram, redrawn by the author (image not recovered). CPD splits the backbone after the third stage into an attention branch and a detection branch, each decoded by a partial decoder over the three deepest feature levels.]

2. RFB module
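The code below uses a BasicConv2d helper that this post does not show. A minimal version, consistent with the official CPD repository (a bias-free Conv2d followed by BatchNorm, with no activation applied in forward), would be:

import torch
import torch.nn as nn


class BasicConv2d(nn.Module):
    # conv + batch norm, no activation (RFB applies its own ReLU)
    def __init__(self, in_planes, out_planes, kernel_size, stride=1,
                 padding=0, dilation=1):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size,
                              stride=stride, padding=padding,
                              dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        return self.bn(self.conv(x))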

class RFB(nn.Module):
    # RFB-like multi-scale module: four parallel branches with increasing
    # receptive fields (factorized k x k convs followed by dilated 3x3 convs),
    # concatenated, fused by a 3x3 conv, and added to a 1x1 residual shortcut
    def __init__(self, in_channel, out_channel):
        super(RFB, self).__init__()
        self.relu = nn.ReLU(True)
        self.branch0 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
        )
        self.branch1 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
            BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
        )
        self.branch2 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
            BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
        )
        self.branch3 = nn.Sequential(
            BasicConv2d(in_channel, out_channel, 1),
            BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
            BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
            BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7)
        )
        self.conv_cat = BasicConv2d(4 * out_channel, out_channel, 3, padding=1)
        self.conv_res = BasicConv2d(in_channel, out_channel, 1)

    def forward(self, x):
        x0 = self.branch0(x)
        x1 = self.branch1(x)
        x2 = self.branch2(x)
        x3 = self.branch3(x)

        x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1))

        x = self.relu(x_cat + self.conv_res(x))
        return x
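A quick shape check (channel sizes here are illustrative; CPD reduces each backbone level to a small common width such as 32). Every branch is padded so the spatial size is preserved:

rfb = RFB(512, 32)
feat = torch.randn(1, 512, 22, 22)  # e.g. a deep backbone feature map
print(rfb(feat).shape)  # torch.Size([1, 32, 22, 22])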

3. aggregation module

class aggregation(nn.Module):
    # dense aggregation; it can be replaced by other aggregation modules,
    # such as DSS, Amulet, and so on
    # used after MSF (the RFB-like multi-scale modules above)
    def __init__(self, channel):
        super(aggregation, self).__init__()
        self.relu = nn.ReLU(True)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv_upsample1 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample2 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample3 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample4 = BasicConv2d(channel, channel, 3, padding=1)
        self.conv_upsample5 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)

        self.conv_concat2 = BasicConv2d(2 * channel, 2 * channel, 3, padding=1)
        self.conv_concat3 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv4 = BasicConv2d(3 * channel, 3 * channel, 3, padding=1)
        self.conv5 = nn.Conv2d(3 * channel, 1, 1)

    def forward(self, x1, x2, x3):
        x1_1 = x1
        x2_1 = self.conv_upsample1(self.upsample(x1)) * x2
        x3_1 = self.conv_upsample2(self.upsample(self.upsample(x1))) \
               * self.conv_upsample3(self.upsample(x2)) * x3

        x2_2 = torch.cat((x2_1, self.conv_upsample4(self.upsample(x1_1))), 1)
        x2_2 = self.conv_concat2(x2_2)

        x3_2 = torch.cat((x3_1, self.conv_upsample5(self.upsample(x2_2))), 1)
        x3_2 = self.conv_concat3(x3_2)

        x = self.conv4(x3_2)
        x = self.conv5(x)

        return x
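The three inputs go from deepest (x1, smallest) to shallowest (x3, largest), each at twice the previous resolution, matching the three RFB outputs. A shape sketch with illustrative sizes:

agg = aggregation(32)
x1 = torch.randn(1, 32, 11, 11)  # deepest features
x2 = torch.randn(1, 32, 22, 22)
x3 = torch.randn(1, 32, 44, 44)  # shallowest of the three
print(agg(x1, x2, x3).shape)  # torch.Size([1, 1, 44, 44]) -- one-channel map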

4. Holistic Attention Module

  • This module aims to enlarge the coverage of the initial saliency prediction map, which in turn helps fix inaccurate boundaries and incomplete results.
  • When an accurate saliency map is obtained from the attention branch, this strategy effectively suppresses distracting features.
  • Conversely, if distractors are classified as salient regions, the strategy yields abnormal segmentation results, so the reliability of the initial saliency map needs to be improved. More concretely, edge information of salient objects may be filtered out by the initial saliency map because it is hard to predict precisely, and some objects in complex scenes are hard to segment completely. A holistic attention module is therefore proposed to enlarge the coverage of the initial saliency map.

The module computes

S_{h} = \mathrm{MAX}(f_{\mathrm{min\text{-}max}}(\mathrm{Conv}_{g}(S_{i}, k)),\ S_{i})

where \mathrm{Conv}_{g}(S_{i}, k) denotes a convolution of the initial saliency map S_{i} with a Gaussian kernel k and zero bias, f_{\mathrm{min\text{-}max}}(\cdot) is a normalization that rescales the blurred map to [0, 1], and \mathrm{MAX}(\cdot) takes the element-wise maximum. This makes S_{h} keep the original values inside salient regions (where the original map exceeds the blurred one) while raising the weights around the boundary of the salient region (where the blurred map exceeds the original), thereby enlarging the area the model attends to.
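A toy 1-D illustration of this blur-normalize-max idea (all numbers here are made up; the real module uses a learnable 2-D Gaussian kernel):

import torch
import torch.nn.functional as F

s_i = torch.tensor([[[[0., 0., 1., 1., 0., 0.]]]])  # initial saliency map S_i
k = torch.tensor([[[[0.25, 0.5, 0.25]]]])           # stand-in Gaussian kernel
blurred = F.conv2d(s_i, k, padding=(0, 1))          # Conv_g(S_i, k)
blurred = (blurred - blurred.min()) / (blurred.max() - blurred.min() + 1e-8)
s_h = torch.max(blurred, s_i)                       # element-wise MAX
print(s_h)  # tensor([[[[0.0000, 0.3333, 1.0000, 1.0000, 0.3333, 0.0000]]]])

The two positions flanking the peak, which were zero in S_i, now receive non-zero weight; that is exactly the "enlarged coverage" described above.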

The attention map S_{h} is multiplied element-wise with the third-level convolutional features to obtain refined features, which are then fed, together with the fourth- and fifth-level features, into the decoder to produce a new saliency prediction (see the wiring sketch after the code below). Compared with the initial attention, the proposed holistic attention adds some computation, but it further highlights the whole salient object.

Note: the paper initializes the size and standard deviation of the Gaussian kernel k to 32 and 4, and the kernel is learned during training. The released code uses gkern(31, 4), an odd size so that padding=15 keeps the convolution size-preserving, and wraps the kernel in Parameter to make it learnable.

import numpy as np
import scipy.stats as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter


class HA(nn.Module):
    # holistic attention module
    def __init__(self):
        super(HA, self).__init__()
        # 31x31 Gaussian kernel (std 4), registered as a learnable parameter
        gaussian_kernel = np.float32(gkern(31, 4))
        gaussian_kernel = gaussian_kernel[np.newaxis, np.newaxis, ...]
        self.gaussian_kernel = Parameter(torch.from_numpy(gaussian_kernel))

    def forward(self, attention, x):
        # blur the initial saliency map; padding=15 preserves spatial size
        soft_attention = F.conv2d(attention, self.gaussian_kernel, padding=15)
        soft_attention = min_max_norm(soft_attention)  # rescale to [0, 1]
        # element-wise max with the original map, then re-weight the features
        x = torch.mul(x, soft_attention.max(attention))
        return x

def gkern(kernlen=16, nsig=3):
    # build a kernlen x kernlen Gaussian-like kernel with std nsig,
    # normalized to sum to 1
    interval = (2 * nsig + 1.) / kernlen
    x = np.linspace(-nsig - interval / 2., nsig + interval / 2., kernlen + 1)
    kern1d = np.diff(st.norm.cdf(x))
    kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
    kernel = kernel_raw / kernel_raw.sum()
    return kernel
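A quick sanity check that the kernel is normalized:

k = gkern(31, 4)
print(k.shape, round(float(k.sum()), 6))  # (31, 31) 1.0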


def min_max_norm(in_):
    # per-sample min-max normalization over the two spatial dimensions
    max_ = in_.max(3)[0].max(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
    min_ = in_.min(3)[0].min(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
    in_ = in_ - min_
    return in_.div(max_ - min_ + 1e-8)
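Putting the pieces together, here is a minimal smoke test of how the attention branch and the holistic attention connect. This is a sketch under assumptions: random tensors stand in for backbone features, and the channel counts and resolutions are illustrative. In the real model x3/x4/x5 come from a VGG or ResNet backbone, and the refined features are passed through the deeper stages again so that a second partial decoder can produce the final detection map.

# hypothetical smoke test; RFB, aggregation, and HA are the classes above
rfb3, rfb4, rfb5 = RFB(256, 32), RFB(512, 32), RFB(512, 32)
agg1 = aggregation(32)
ha = HA()

x3 = torch.randn(1, 256, 44, 44)  # stand-in for level-3 backbone features
x4 = torch.randn(1, 512, 22, 22)  # level 4
x5 = torch.randn(1, 512, 11, 11)  # level 5

# attention branch: partial decoder over the reduced deep features
attention = agg1(rfb5(x5), rfb4(x4), rfb3(x3))  # (1, 1, 44, 44)

# holistic attention re-weights the level-3 features with the initial map
x3_refined = ha(attention.sigmoid(), x3)
print(attention.shape, x3_refined.shape)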
