参考代码:CCNet
1. 概述
导读:CNN网络中较大范围的依赖(long-range dependencies)可以捕捉到很多有用的上下文信息,这个特性在图像理解任务中具有重要作用(如分割)。文章在参考non-local设计理念的基础上使用在像素点位置十字交叉的方式进行attention操作,用以获取丰富的上下文信息,提出由CCA模块构建的CCNet(criss-cross Network)。文章的方法相比之前的non-local具有如下两个优点:
1)相比non-local在显存上的开销更小,之间差距差了11倍;
2)在计算开销上,相比non-local减少了85%的计算量;
文章的方法在Cityscapes和ADE20K数据集上mIoU分别达到了81.4和45.22。
文章网络attention模块设计的思路很大程度上是来自于non-local方法,文中在其基础上对计算效率和显存占用做了优化。下图是一个典型的non-local的计算流程图:
其在分辨率为 ( H ∗ W ) (H*W) (H∗W)下计算量可以大体描述为 O ( ( H ∗ W ) ∗ ( H ∗ W ) ) O((H*W)*(H*W)) O((H∗W)∗(H∗W))。对此文章使用十字型采样的方式减少对应的资源消耗,但是为了弥补简单十字型采样带来的表达不足,文章使用了参数共享的堆叠方式,从而得到下图中的attention计算流程:
则经过十字型采样之后整个的计算量大体变为了 O ( ( W + H − 1 ) ∗ H ∗ W ) O((W+H-1)*H*W) O((W+H−1)∗H∗W)。
在下表中比较了在相同baseline下non-local和CCA模块的性能比较,见下表所示:
2. 方法设计
2.1 网络结构
文章的整体网络结构比较简单,具体见下图所示:
输入的图像首先经过一个带有dilation convolution的卷积得到特征图 X X X(其stride=8),之后经过一个channel采样的卷积通道数下采样得到特征图 H H H,之后经过两个权值共享的CCA模块得到经过优化的特征图,之后送入后面的分割头得到最后的分割结果。
2.2 CCA模块
文章的CCA模块其具体结构见下图所示:
输入的特征图表示为 H ∈ R C ∗ W ∗ H H\in R^{C*W*H} H∈RC∗W∗H,之后分别经过两个 1 ∗ 1 1*1 1∗