paper:Segmentation Transformer: Object-Contextual Representations for Semantic Segmentation
github:https://github.com/HRNet/HRNet-Semantic-Segmentation/tree/HRNet-OCR
语义分割任务中,像素所属的类别就是像素所在的对象的类别,能不能利用像素和其所属目标的关系呢?OCRNet提出了一个有效方法,利用像素所属的目标的上下文信息来提升像素的表征(representation)。(本文说的表征,通俗理解就是每个像素对应的embedding)
该方法分为3步:
(1)根据类别数生成soft object region(可以理解为粗糙的分割结果);
(2)通过整合目标的所有像素的表征来估计目标区域表征( object region representation);
(3)根据像素和其所在目标的关系,计算目标上下文表征(object-contextual representation),并用目标上下文表征增强像素表征;
目录
2、Object region representations
3、Object contextual representations
一、计算公式
公式中的函数实现:1x1 conv->BN->relu
1、Soft object regions
将输入图像I划分为K个Soft object regions。
简单理解,假设输入图像是[N, C, H, W], 那么Soft object regions维度就是[N, num_classes, H, W],其中的元素值表示了该像素属于某个类别的程度(其实就是语义分割头没有softmax的输出)。
2、Object region representations
首先给出论文的公式 如下:
乍看一脸懵,立马跑去看看源码,是softmax输出的概率值。
简单的理解:对于同一个目标的所有像素,需要学习一个embedding来表征。假设输入X维度为[N, C, H*W],Soft object regions维度为[N, num_classes, H*W],一个目标有多个像素,需要学习一个统一的表征,那么学习到的Object region representations维度为[N, num_classes, C]。
3、Object contextual representations
首先计算像素表征和目标表征的关系(表示第i个像素表征和第k个目标区域表征的关系):
利用下式计算目标上下文表征。
这里的计算过程和Attention相同:
可以把看作q,维度[N, H*W, C],看作k,维度为[N, C, num_classes],看作v,维度为[N, num_classes, C],y的维度为[N, H*W, C]。
4、Augmented representations
将目标上下文表征和像素表征结合起来,增强像素表征。
二、网络结构
OCRNet的网络结构如图所示。
设Pixel Representations维度为[N, C, H, W], Soft Object Regions维度为[N, num_classes, H, W]。
Object Region Representations: [N, num_classes, C]
Pixel Region Relation: [N, num_classes, H, W]
Object Contextual Representation: [N, C, H, W] (图中的去除pixel representations蓝色小方块的部分)
Augmented Representations: [N, 2C, H, W]
三、实验结果
OCRNet在各数据集上性能如下。
四、总结
语义分割任务中,同一物体的所有像素的embedding应该相似,OCRNet通过Attention机制显式地利用了这种关系来提升语义分割的效果。