OCRNet原理与代码解析(ECCV 2020)

paper:Object-Contextual Representations for Semantic Segmentation

official implementation:https://github.com/HRNet/HRNet-Semantic-Segmentation

third-party implementation:https://github.com/open-mmlab/mmsegmentation/blob/main/mmseg/models/decode_heads/ocr_head.py

本文的创新点

本文聚焦于用上下文聚合策略context aggregation strategy来处理语义分割问题。本文的启发来源于一个像素的类别应该是这个像素所属对象的类别。本文提出了一个简单有效的方法,对象-上下文表示object-contextual representations,通过利用对应对象的类别来描述一个像素。首先在ground-truth分割的监督下学习目标区域,然后通过聚合目标区域内像素的表示来计算目标区域的表示,最后计算每个像素与每个目标区域的关系并使用object-contextual representation来增强每个像素的表示,其中object-contextual representation是所有目标区域表示与像素关系的加权聚合。

方法介绍

像素 \(p_{i}\) 的类别 \(l_{i}\) 本质上是 \(p_{i}\) 所在对象的类别。对象-上下文表示包括:(1)将图像 \(I\) 中的所有像素结构化分为 \(K\) 个软对象区域(2)通过聚合第 \(k\) 个对象区域中所有像素的表示用 \(\mathbf{f}_{k}\) 来示每个目标区域(3)基于 \(K\) 个目标区域和所有目标区域的关系,通过聚合 \(K\) 个目标区域的表示来增强每个像素的表示

其中 \(\mathbf{f}_{k}\) 是第 \(k\) 个目标区域的表示,\(w_{ik}\) 表示第 \(i\) 个像素和第 \(k\) 个目标区域的关系,\(\delta(\cdot)\) 和 \(\rho (\cdot)\) 是变换函数。

Soft object regions

将图像 \(I\) 划分为 \(K\) 个软目标区域 \(\left \{ \mathbf{M}_{1},\mathbf{M}_{2},...,\mathbf{M}_{K} \right \} \),每个目标区域 \(\mathbf{M}_{k}\) 对应类别 \(k\),并由一个2D map或一个粗略的segmentation map来表示,其中每个值表示这个位置的像素属于类别 \(k\) 的程度。我们根据骨干网络的中间输出来计算这 \(K\) 个目标区域。在训练过程中,在ground-truth segmentation的监督下用交叉熵损失来学习目标区域。

Object region representations

我们根据所有像素对第 \(k\) 个目标区域的所属程度进行加权聚合,从而得到第 \(k\) 个目标区域的表示

其中 \(\mathbf{x}_{i}\) 表示像素 \(p_{i}\),\(\widetilde{m}_{ki} \) 表示像素 \(p_{i}\) 属于第 \(k\) 个目标区域的归一化后的程度。我们使用spatial softmax来归一化每个目标区域 \(\mathbf{M}_{k}\)。

Object contextual representations

我们按下式计算每个像素和每个目标区域的关系

其中 \(\kappa (\mathbf{x},\mathbf{f})=\phi(\mathbf{x})^{\mathsf{T} }\psi (\mathbf{f})\) 是未归一化的关系函数,\(\phi(\cdot)\) 和 \(\psi(\cdot)\) 是两个转换函数具体实现为1x1 conv —>BN —>ReLU。这里是受到了self-attention的启发。

像素 \(p_{i}\) 的object contextual representation \(\mathbf{y}_{i}\) 根据式(3)计算得到。其中 \(\rho(\cdot)\) 和 \(\delta(\cdot)\) 也是由1x1 conv —>BN —>ReLU实现的两个转换函数。这里是受到non-local networks的启发。

Augmented representations

像素 \(p_{i}\) 的最终表示由两部分组成,一是原始表示 \(\mathbf{x}_{i}\),二是对象-上下文表示 \(\mathbf{y}_{i}\)

其中 \(g(\cdot)\) 是由1x1 conv —>BN —>ReLU实现的转换函数,用于融合原始表示和对象上下文表示。

整个pipeline如下图所示

代码解析

这里以mmsegmentation中的实现为例,介绍一下实现代码。输入shape=(8, 3, 480, 480),backbone采用ResNet-50,配置如下 

可以看出和原始的ResNet-50不同的是,只有stage1的stride=2,因此经过backbone的输出shape为[(8,256,120,120),(8,512,60,60),(8,1024,60,60),(8,2048,60,60)]。 

接下来head部分是一个级联head的设计,首先是一个FCNHead,然后是一个OCRHead,配置如下。

Head部分的实现如下

def _decode_head_forward_train(self, inputs: Tensor,
                               data_samples: SampleList) -> dict:
    """Run forward function and calculate loss for decode head in
    training."""
    losses = dict()

    loss_decode = self.decode_head[0].loss(inputs, data_samples,
                                           self.train_cfg)

    losses.update(add_prefix(loss_decode, 'decode_0'))
    # get batch_img_metas
    batch_size = len(data_samples)
    batch_img_metas = []
    for batch_index in range(batch_size):
        metainfo = data_samples[batch_index].metainfo
        batch_img_metas.append(metainfo)

    for i in range(1, self.num_stages):
        # forward test again, maybe unnecessary for most methods.
        if i == 1:
            prev_outputs = self.decode_head[0].forward(inputs)  # ocrnet_r50, (8,2,60,60)
        else:
            prev_outputs = self.decode_head[i - 1].forward(
                inputs, prev_outputs)
        loss_decode = self.decode_head[i].loss(inputs, prev_outputs,
                                               data_samples,
                                               self.train_cfg)
        losses.update(add_prefix(loss_decode, f'decode_{i}'))

    return losses

其中self.decode_head[0]就是FCNHead,对应图3中的粉色框。从图3和代码实现中可以看出,这里FCNHead既受GT的监督计算损失即line7的loss_decode,同时又输出Soft Object Regions即line21的prev_outputs作为输入送入OCRHead中。 

OCRHead的实现部分如下

def forward(self, inputs, prev_output):  # [(8,256,120,120),(8,512,60,60),(8,1024,60,60),(8,2048,60,60)], (8,2,60,60)
    """Forward function."""
    x = self._transform_inputs(inputs)  # (8,2048,60,60)
    feats = self.bottleneck(x)  # (8,512,60,60)
    context = self.spatial_gather_module(feats, prev_output)  # 式(4)得到f_{k}, prev_output就是M_{k}, (8,512,2,1)
    object_context = self.object_context_block(feats, context)  # (8,512,60,60)
    output = self.cls_seg(object_context)  # (8,2,60,60)

    return output

其中输入inputs就是backbone的输出,prev_output就是FCNHead的输出。首先self._transform_inputs对输入进行转换,配置文件中in_index=3,这里直接就是根据索引取值。

然后self.bottleneck就是一个3x3卷积。

接着self.spatial_gather_module的实现如下,对应的是式(4)。其中probs就是FCNHead得到的object region \(M_{k}\),通过F.softmax来实现spatial softmax,最终得到的就是式(4)的输出 \(\mathbf{f}_{k}\)。

class SpatialGatherModule(nn.Module):
    """Aggregate the context features according to the initial predicted
    probability distribution.

    Employ the soft-weighted method to aggregate the context.
    """

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, feats, probs):  # (8,512,60,60),(8,2,60,60)
        """Forward function."""
        batch_size, num_classes, height, width = probs.size()
        channels = feats.size(1)
        probs = probs.view(batch_size, num_classes, -1)  # (8,2,3600)
        feats = feats.view(batch_size, channels, -1)  # (8,512,3600)
        # [batch_size, height*width, num_classes]
        feats = feats.permute(0, 2, 1)  # (8,3600,512)
        # [batch_size, channels, height*width]
        probs = F.softmax(self.scale * probs, dim=2)  # 式(4)中的spatial softmax, (8,2,3600)
        # [batch_size, channels, num_classes]
        ocr_context = torch.matmul(probs, feats)  # (8,2,512)
        ocr_context = ocr_context.permute(0, 2, 1).contiguous().unsqueeze(3)  # (8,512,2,1)
        return ocr_context

然后self.object_context_block是式(5)、式(3)、式(6)的具体实现,forward部分如下,先看第二步对应的就是式(6),其中self.bottleneck就是 \(g(\cdot)\),context就是最终得到的对象-上下文表示 \(\mathbf{y}_{i}\)。

def forward(self, query_feats, key_feats):
    """Forward function."""
    context = super().forward(query_feats, key_feats)
    output = self.bottleneck(torch.cat([context, query_feats], dim=1))  # 式(6),对应g(.)
    if self.query_downsample is not None:
        output = resize(query_feats)

    return output

然后看第一步,这里直接调用的是SelfAttentionBlock,forward部分如下

def forward(self, query_feats, key_feats):  # 式(5)中的x_{i}, f_{k}
    # ocrnet_r50_d8
    # (8,512,60,60), (8,512,2,1)
    """Forward function."""
    batch_size = query_feats.size(0)
    query = self.query_project(query_feats)  # \phi, (8,256,60,60)
    if self.query_downsample is not None:
        query = self.query_downsample(query)
    query = query.reshape(*query.shape[:2], -1)  # (8,256,3600)
    query = query.permute(0, 2, 1).contiguous()  # (8,3600,256)

    key = self.key_project(key_feats)  # \psi, (8,256,2,1)
    value = self.value_project(key_feats)  # 式(3)中的\delta, (8,256,2,1)
    if self.key_downsample is not None:
        key = self.key_downsample(key)
        value = self.key_downsample(value)
    key = key.reshape(*key.shape[:2], -1)  # (8,256,2)
    value = value.reshape(*value.shape[:2], -1)  # (8,256,2)
    value = value.permute(0, 2, 1).contiguous()  # (8,2,256)

    sim_map = torch.matmul(query, key)  # (8,3600,2)
    if self.matmul_norm:  # True
        sim_map = (self.channels**-.5) * sim_map  # 256,0.0625
    sim_map = F.softmax(sim_map, dim=-1)  # 用softmax来实现式(5),得到w_{ik}, (8,3600,2)

    context = torch.matmul(sim_map, value)  # (8,3600,256)
    context = context.permute(0, 2, 1).contiguous()  # (8,256,3600)
    context = context.reshape(batch_size, -1, *query_feats.shape[2:])  # (8,256,60,60)
    if self.out_project is not None:
        context = self.out_project(context)  # 式(3)中的\rho, (8,512,60,60)
    return context

其中self.query_project、self.key_project、self.value_project分别对应 \(\phi(\cdot),\psi(\cdot),\delta(\cdot)\),最后的self.out_project对应 \(\rho(\cdot)\),line24用F.softmax来实现式(5)得到 \(w_{ik}\)。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

00000cj

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值