paper: CCNet: Criss-Cross Attention for Semantic Segmentation
github: https://github.com/speedinghzl/CCNet/tree/pure-python
上下文信息在语义分割任务中非常重要,CCNet提出了criss-cross attention模块,同时引入循环操作,使得图片中每个像素都可以和其他像素建立联系,从而使得每个像素都可以获得丰富的语义信息。此外提出了category consistent loss使得criss-cross attention模块能够产生更具有判别性的特征(category consistent loss代码本人未在源码中看到,本文将不关注此部分)。
一 、网络
1、网络结构
CCNet网络结构如下图所示,CNN表示特征提取器(backbone),Reduction减少特征图的通道数以减少后续计算量,Criss-Cross Attention用来建立不同位置像素间的联系从而丰富其语义信息,R表示Criss-Cross Attention Module的循环次数,注意多个Criss-Cross Attention Module共享参数。
2、Criss-Cross Attention Module
Criss-Cross Attention Module结构如下图所示:
假设输入为X:[N, C, H, W],为了让一个像素与其他位置像素建立联系,首先在该像素的纵向和横向建立联系,以纵向为例:
①通过1x1卷积,得到 Q_h:[N, Cr, H, W],K_h:[N, Cr, H, W], V_h:[N, C, H, W](Q_w\K_w\V_w同理);
②维度变换,reshape得到 Q_h:[N * W,H,Cr],K_h:[N * W,Cr,H], V_h:[N * W,C,H];
③Q_h和K_h矩阵乘法,得到energy_h:[N * W, H, H];(源码中Enegy_H计算时加上了个维度为[N*W, H, H]的对角-inf矩阵,但是energy_w计算时没加,有点没搞懂。。)
④类似上面的流程,得到energy_h:[N * W, H, H]和energy_w:[N * H, W, W],reshape后维度变换得到energy_h:[N, H, W, H]和energy_w:[N, H, W, W],拼接得到energy:[N, H, W, H + W];
⑤在energy最后一个维度使用softmax,得到attention系数;
⑥将attention系数拆分为attn_h:[N, H, W, H]和attn_w:[N, H, W, W],维度变换后与V_h和V_w分别相乘得到输出out_h和out_w;
⑦将out_h+out_w,并乘上一个系数γ(可学习参数),再加上residual connection,得到最终输出。
代码如下:
def INF(B,H,W):
return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H),0).unsqueeze(0).repeat(B*W,1,1)
class CrissCrossAttention(nn.Module):
""" Criss-Cross Attention Module"""
def __init__(self, in_dim):
super(CrissCrossAttention,self).__init__()
self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim//8, kernel_size=1)
self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
self.softmax = Softmax(dim=3)
self.INF = INF
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
m_batchsize, _, height, width = x.size()
proj_query = self.query_conv(x)
proj_query_H = proj_query.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height).permute(0, 2, 1)
proj_query_W = proj_query.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width).permute(0, 2, 1)
proj_key = self.key_conv(x)
proj_key_H = proj_key.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
proj_key_W = proj_key.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
proj_value = self.value_conv(x)
proj_value_H = proj_value.permute(0,3,1,2).contiguous().view(m_batchsize*width,-1,height)
proj_value_W = proj_value.permute(0,2,1,3).contiguous().view(m_batchsize*height,-1,width)
energy_H = (torch.bmm(proj_query_H, proj_key_H)+self.INF(m_batchsize, height, width)).view(m_batchsize,width,height,height).permute(0,2,1,3)
energy_W = torch.bmm(proj_query_W, proj_key_W).view(m_batchsize,height,width,width)
concate = self.softmax(torch.cat([energy_H, energy_W], 3))
att_H = concate[:,:,:,0:height].permute(0,2,1,3).contiguous().view(m_batchsize*width,height,height)
#print(concate)
#print(att_H)
att_W = concate[:,:,:,height:height+width].contiguous().view(m_batchsize*height,width,width)
out_H = torch.bmm(proj_value_H, att_H.permute(0, 2, 1)).view(m_batchsize,width,-1,height).permute(0,2,3,1)
out_W = torch.bmm(proj_value_W, att_W.permute(0, 2, 1)).view(m_batchsize,height,-1,width).permute(0,2,1,3)
#print(out_H.size(),out_W.size())
return self.gamma*(out_H + out_W) + x
再回到上面的结构图,经过Criss-Cross Attention模块后,每个像素与其横向和纵向所有像素建立联系,只需要再经过Criss-Cross Attention模块,每个像素就与其他所有像素建立了联系,从而丰富了语义信息。
PS:此外作者拓展了3D Criss-Cross Attention模块,此处不再介绍。
二、实验结果
1、训练策略
优化器:SGD(动量0.9,weight_decay 0.0001)
学习率:多项式策略
数据增强:Random scaling (0.75-2.0)
Random cropping (cityscapes:769x769)
对ADE20K,使用resizebyshort,从{300, 375, 450, 525, 600}中选择一个值,将图片短边resize为该值。
2、RCCA
再cityscapes验证集上对CCA循环次数做了对比,可以看出,当R=2的时候,miou提升非常明显,当R=3时提升相对小很多。因为当R=1的时候,每个像素只能够得到其纵向和横向的语义信息,当R=2时,可以得到全局语义信息。
3、performance
CCNet再cityscapes验证集和测试集上的表现如下: