【论文笔记】CCNet阅读笔记

21 篇文章 8 订阅
17 篇文章 6 订阅

paper: CCNet: Criss-Cross Attention for Semantic Segmentation

github: https://github.com/speedinghzl/CCNet/tree/pure-python

上下文信息在语义分割任务中非常重要,CCNet提出了criss-cross attention模块,同时引入循环操作,使得图片中每个像素都可以和其他像素建立联系,从而使得每个像素都可以获得丰富的语义信息。此外提出了category consistent loss使得criss-cross attention模块能够产生更具有判别性的特征(category consistent loss代码本人未在源码中看到,本文将不关注此部分)。

一 、网络

1、网络结构

CCNet网络结构如下图所示,CNN表示特征提取器(backbone),Reduction减少特征图的通道数以减少后续计算量,Criss-Cross Attention用来建立不同位置像素间的联系从而丰富其语义信息,R表示Criss-Cross Attention Module的循环次数,注意多个Criss-Cross Attention Module共享参数

2、Criss-Cross Attention Module

Criss-Cross Attention Module结构如下图所示:

假设输入为X:[N, C, H, W],为了让一个像素与其他位置像素建立联系,首先在该像素的纵向和横向建立联系,以纵向为例:

①通过1x1卷积,得到 Q_h:[N, Cr, H, W],K_h:[N, Cr, H, W], V_h:[N, C, H, W](Q_w\K_w\V_w同理);

②维度变换,reshape得到 Q_h:[N * W,H,Cr],K_h:[N * W,Cr,H], V_h:[N * W,C,H];

③Q_h和K_h矩阵乘法,得到energy_h:[N * W, H, H];(源码中Enegy_H计算时加上了个维度为[N*W, H, H]的对角-inf矩阵,但是energy_w计算时没加,有点没搞懂。。)

④类似上面的流程,得到energy_h:[N * W, H, H]和energy_w:[N * H, W, W],reshape后维度变换得到energy_h:[N, H, W, H]和energy_w:[N, H, W, W],拼接得到energy:[N, H, W, H + W];

⑤在energy最后一个维度使用softmax,得到attention系数;

⑥将attention系数拆分为attn_h:[N, H, W, H]和attn_w:[N, H, W, W],维度变换后与V_h和V_w分别相乘得到输出out_h和out_w;

⑦将out_h+out_w,并乘上一个系数γ(可学习参数),再加上residual connection,得到最终输出。

代码如下:

def INF(B,H,W):
     return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H),0).unsqueeze(0).repeat(B*W,1,1)


class CrissCrossAttention(nn.Module):
    """ Criss-Cross Attention Module"""
    def __init__(self, in_dim):
        super(CrissCrossAttention,self).__init__()
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = Softmax(dim=3)
        self.INF = INF
        self.gamma = nn.Parameter(torch.zeros(1))


    def forward(self, x):
        m_batchsize, _, height, width = x.size()
        proj_query = self.query_conv(x)
        proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)
        proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)
        proj_key = self.key_conv(x)
        proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        proj_value = self.value_conv(x)
        proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
        proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
        energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)
        energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)
        concate = self.softmax(torch.cat([energy_H, energy_W], 3))

        att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)
        #print(concate)
        #print(att_H) 
        att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)
        out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)
        out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)
        #print(out_H.size(),out_W.size())
        return self.gamma*(out_H + out_W) + x

再回到上面的结构图,经过Criss-Cross Attention模块后,每个像素与其横向和纵向所有像素建立联系,只需要再经过Criss-Cross Attention模块,每个像素就与其他所有像素建立了联系,从而丰富了语义信息。

PS:此外作者拓展了3D Criss-Cross Attention模块,此处不再介绍。

二、实验结果

1、训练策略

优化器:SGD(动量0.9,weight_decay 0.0001)

学习率:多项式策略1-(\frac{iter}{max\_iter})^{power}

数据增强:Random scaling (0.75-2.0)

                  Random cropping (cityscapes:769x769)

对ADE20K,使用resizebyshort,从{300, 375, 450, 525, 600}中选择一个值,将图片短边resize为该值。

2、RCCA

再cityscapes验证集上对CCA循环次数做了对比,可以看出,当R=2的时候,miou提升非常明显,当R=3时提升相对小很多。因为当R=1的时候,每个像素只能够得到其纵向和横向的语义信息,当R=2时,可以得到全局语义信息。

 3、performance

CCNet再cityscapes验证集和测试集上的表现如下:

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

justld

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值