paper:Object-Contextual Representations for Semantic Segmentation
official implementation:https://github.com/HRNet/HRNet-Semantic-Segmentation
third-party implementation:https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/ocr_head.py
本文的创新点
本文聚焦于用上下文聚合策略context aggregation strategy来处理语义分割问题。本文的启发来源于一个像素的类别应该是这个像素所属对象的类别。本文提出了一个简单有效的方法,对象-上下文表示object-contextual representations,通过利用对应对象的类别来描述一个像素。首先在ground-truth分割的监督下学习目标区域,然后通过聚合目标区域内像素的表示来计算目标区域的表示,最后计算每个像素与每个目标区域的关系并使用object-contextual representation来增强每个像素的表示,其中object-contextual representation是所有目标区域表示与像素关系的加权聚合。
方法介绍
像素 \(p_{i}\) 的类别 \(l_{i}\) 本质上是 \(p_{i}\) 所在对象的类别。对象-上下文表示包括:(1)将图像 \(I\) 中的所有像素结构化分为 \(K\) 个软对象区域(2)通过聚合第 \(k\) 个对象区域中所有像素的表示用 \(\mathbf{f}_{k}\) 来示每个目标区域(3)基于 \(K\) 个目标区域和所有目标区域的关系,通过聚合 \(K\) 个目标区域的表示来增强每个像素的表示
其中 \(\mathbf{f}_{k}\) 是第 \(k\) 个目标区域的表示,\(w_{ik}\) 表示第 \(i\) 个像素和第 \(k\) 个目标区域的关系,\(\delta(\cdot)\) 和 \(\rho (\cdot)\) 是变换函数。
Soft object regions
将图像 \(I\) 划分为 \(K\) 个软目标区域 \(\left \{ \mathbf{M}_{1},\mathbf{M}_{2},...,\mathbf{M}_{K} \right \} \),每个目标区域 \(\mathbf{M}_{k}\) 对应类别 \(k\),并由一个2D map或一个粗略的segmentation map来表示,其中每个值表示这个位置的像素属于类别 \(k\) 的程度。我们根据骨干网络的中间输出来计算这 \(K\) 个目标区域。在训练过程中,在ground-truth segmentation的监督下用交叉熵损失来学习目标区域。
Object region representations
我们根据所有像素对第 \(k\) 个目标区域的所属程度进行加权聚合,从而得到第 \(k\) 个目标区域的表示
其中 \(\mathbf{x}_{i}\) 表示像素 \(p_{i}\),\(\widetilde{m}_{ki} \) 表示像素 \(p_{i}\) 属于第 \(k\) 个目标区域的归一化后的程度。我们使用spatial softmax来归一化每个目标区域 \(\mathbf{M}_{k}\)。
Object contextual representations
我们按下式计算每个像素和每个目标区域的关系
其中 \(\kappa (\mathbf{x},\mathbf{f})=\phi(\mathbf{x})^{\mathsf{T} }\psi (\mathbf{f})\) 是未归一化的关系函数,\(\phi(\cdot)\) 和 \(\psi(\cdot)\) 是两个转换函数具体实现为1x1 conv —>BN —>ReLU。这里是受到了self-attention的启发。
像素 \(p_{i}\) 的object contextual representation \(\mathbf{y}_{i}\) 根据式(3)计算得到。其中 \(\rho(\cdot)\) 和 \(\delta(\cdot)\) 也是由1x1 conv —>BN —>ReLU实现的两个转换函数。这里是受到non-local networks的启发。
Augmented representations
像素 \(p_{i}\) 的最终表示由两部分组成,一是原始表示 \(\mathbf{x}_{i}\),二是对象-上下文表示 \(\mathbf{y}_{i}\)
其中 \(g(\cdot)\) 是由1x1 conv —>BN —>ReLU实现的转换函数,用于融合原始表示和对象上下文表示。
整个pipeline如下图所示
代码解析
这里以mmsegmentation中的实现为例,介绍一下实现代码。输入shape=(8, 3, 480, 480),backbone采用ResNet-50,配置如下
可以看出和原始的ResNet-50不同的是,只有stage1的stride=2,因此经过backbone的输出shape为[(8,256,120,120),(8,512,60,60),(8,1024,60,60),(8,2048,60,60)]。
接下来head部分是一个级联head的设计,首先是一个FCNHead,然后是一个OCRHead,配置如下。
Head部分的实现如下
def _decode_head_forward_train(self, inputs: Tensor,
data_samples: SampleList) -> dict:
"""Run forward function and calculate loss for decode head in
training."""
losses = dict()
loss_decode = self.decode_head[0].loss(inputs, data_samples,
self.train_cfg)
losses.update(add_prefix(loss_decode, 'decode_0'))
# get batch_img_metas
batch_size = len(data_samples)
batch_img_metas = []
for batch_index in range(batch_size):
metainfo = data_samples[batch_index].metainfo
batch_img_metas.append(metainfo)
for i in range(1, self.num_stages):
# forward test again, maybe unnecessary for most methods.
if i == 1:
prev_outputs = self.decode_head[0].forward(inputs) # ocrnet_r50, (8,2,60,60)
else:
prev_outputs = self.decode_head[i - 1].forward(
inputs, prev_outputs)
loss_decode = self.decode_head[i].loss(inputs, prev_outputs,
data_samples,
self.train_cfg)
losses.update(add_prefix(loss_decode, f'decode_{i}'))
return losses
其中self.decode_head[0]就是FCNHead,对应图3中的粉色框。从图3和代码实现中可以看出,这里FCNHead既受GT的监督计算损失即line7的loss_decode,同时又输出Soft Object Regions即line21的prev_outputs作为输入送入OCRHead中。
OCRHead的实现部分如下
def forward(self, inputs, prev_output): # [(8,256,120,120),(8,512,60,60),(8,1024,60,60),(8,2048,60,60)], (8,2,60,60)
"""Forward function."""
x = self._transform_inputs(inputs) # (8,2048,60,60)
feats = self.bottleneck(x) # (8,512,60,60)
context = self.spatial_gather_module(feats, prev_output) # 式(4)得到f_{k}, prev_output就是M_{k}, (8,512,2,1)
object_context = self.object_context_block(feats, context) # (8,512,60,60)
output = self.cls_seg(object_context) # (8,2,60,60)
return output
其中输入inputs就是backbone的输出,prev_output就是FCNHead的输出。首先self._transform_inputs对输入进行转换,配置文件中in_index=3,这里直接就是根据索引取值。
然后self.bottleneck就是一个3x3卷积。
接着self.spatial_gather_module的实现如下,对应的是式(4)。其中probs就是FCNHead得到的object region \(M_{k}\),通过F.softmax来实现spatial softmax,最终得到的就是式(4)的输出 \(\mathbf{f}_{k}\)。
class SpatialGatherModule(nn.Module):
"""Aggregate the context features according to the initial predicted
probability distribution.
Employ the soft-weighted method to aggregate the context.
"""
def __init__(self, scale):
super().__init__()
self.scale = scale
def forward(self, feats, probs): # (8,512,60,60),(8,2,60,60)
"""Forward function."""
batch_size, num_classes, height, width = probs.size()
channels = feats.size(1)
probs = probs.view(batch_size, num_classes, -1) # (8,2,3600)
feats = feats.view(batch_size, channels, -1) # (8,512,3600)
# [batch_size, height*width, num_classes]
feats = feats.permute(0, 2, 1) # (8,3600,512)
# [batch_size, channels, height*width]
probs = F.softmax(self.scale * probs, dim=2) # 式(4)中的spatial softmax, (8,2,3600)
# [batch_size, channels, num_classes]
ocr_context = torch.matmul(probs, feats) # (8,2,512)
ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3) # (8,512,2,1)
return ocr_context
然后self.object_context_block是式(5)、式(3)、式(6)的具体实现,forward部分如下,先看第二步对应的就是式(6),其中self.bottleneck就是 \(g(\cdot)\),context就是最终得到的对象-上下文表示 \(\mathbf{y}_{i}\)。
def forward(self, query_feats, key_feats):
"""Forward function."""
context = super().forward(query_feats, key_feats)
output = self.bottleneck(torch.cat([context, query_feats], dim=1)) # 式(6),对应g(.)
if self.query_downsample is not None:
output = resize(query_feats)
return output
然后看第一步,这里直接调用的是SelfAttentionBlock,forward部分如下
def forward(self, query_feats, key_feats): # 式(5)中的x_{i}, f_{k}
# ocrnet_r50_d8
# (8,512,60,60), (8,512,2,1)
"""Forward function."""
batch_size = query_feats.size(0)
query = self.query_project(query_feats) # \phi, (8,256,60,60)
if self.query_downsample is not None:
query = self.query_downsample(query)
query = query.reshape(*query.shape[:2], -1) # (8,256,3600)
query = query.permute(0, 2, 1).contiguous() # (8,3600,256)
key = self.key_project(key_feats) # \psi, (8,256,2,1)
value = self.value_project(key_feats) # 式(3)中的\delta, (8,256,2,1)
if self.key_downsample is not None:
key = self.key_downsample(key)
value = self.key_downsample(value)
key = key.reshape(*key.shape[:2], -1) # (8,256,2)
value = value.reshape(*value.shape[:2], -1) # (8,256,2)
value = value.permute(0, 2, 1).contiguous() # (8,2,256)
sim_map = torch.matmul(query, key) # (8,3600,2)
if self.matmul_norm: # True
sim_map = (self.channels**-.5) * sim_map # 256,0.0625
sim_map = F.softmax(sim_map, dim=-1) # 用softmax来实现式(5),得到w_{ik}, (8,3600,2)
context = torch.matmul(sim_map, value) # (8,3600,256)
context = context.permute(0, 2, 1).contiguous() # (8,256,3600)
context = context.reshape(batch_size, -1, *query_feats.shape[2:]) # (8,256,60,60)
if self.out_project is not None:
context = self.out_project(context) # 式(3)中的\rho, (8,512,60,60)
return context
其中self.query_project、self.key_project、self.value_project分别对应 \(\phi(\cdot),\psi(\cdot),\delta(\cdot)\),最后的self.out_project对应 \(\rho(\cdot)\),line24用F.softmax来实现式(5)得到 \(w_{ik}\)。