【小白入门】超详细的OCRnet详解(含代码分析)
本文仅梳理总结自己在学习过程中的一些理解和思路,不保证绝对正确,请酌情参考。如果各位朋友发现任何错误请及时告诉我,大家一起讨论共同提高。
本文参考博客https://blog.csdn.net/u011622208/article/details/110202659 及http://yearing1017.cn/2020/09/07/OCRNet/,有部分内容为直接引用,如有侵权,烦请告知(我删的贼快)
论文:https://arxiv.org/pdf/1909.11065.pdf
已开源代码地址:https://github.com/HRNet/HRNet-Semantic-Segmentation/tree/HRNet-OCR
OCRnet
针对语义分割中如何构建上下文信息,微软亚洲研究院和中科院计算所的研究员们提出了一种新的物体上下文信息——在构建上下文信息时显式地增强了来自于同一类物体的像素的贡献,这种新的上下文信息从语义分割的定义出发,符合第一性原理思维,在2019年7月和2020年1月的 Cityscapes leaderboard 提交结果中都取得了语义分割任务第一名的成绩。相关工作“Object-Contextual Representations for Semantic Segmentation”已经被 ECCV 2020 收录。
简介
对于语义分割任务来说,其两大关键是:分辨率和上下文。
- 语义分割一个密集像素预测任务,因此空间分辨率很重要。
- 像素本身不具备语义,它的语义由其图像整体或目标区域决定,因此它对上下文高度依赖。
- 一个像素位置的上下文指的是它周围的像素位置。
该论文的主要思想也就是像素的类别标签是由它所在的目标的类别标签决定的。主要思路是利用目标区域表示来增强其像素的表示。与之前的考虑上下文关系的方法不同的是,之前的方法考虑的是上下文像素之间的关系,没有显示利用目标区域的特征。
网络结构
官方给出的网络结构图如图
其中,粉红色虚线框内为形成的软对象区域(Soft Object Regions),紫色虚线框中为物体区域表示(Object Region Representations),橙色虚线框中为对象上下文表示和增强表示。
具体实现(含代码分析)
论文中指出,OCR 方法的实现主要包括3个阶段:
First, we divide the contextual pixels into a set of soft object regions with each corresponding to a class, i.e., a coarse soft segmentation computed from a deep network (e.g., ResNet [23] or HRNet [55]). Such division is learned under the supervision of the ground-truth segmentation.
第一步: 将上下文像素划分为一组软对象区域,每个soft object regions对应一个类,即从深度网络(backbone)计算得到的粗软分割(粗略的语义分割结果)。这种划分是在ground-truth分割的监督下学习的。根据网络中间层的特征表示估测粗略的语义分割结果作为 OCR 方法的一个输入,即结构图中粉红色框内的Soft Object Regions
self.aux_head = nn.Sequential(
nn.Conv2d(high_level_ch, high_level_ch,
kernel_size=1, stride=1, padding=0),
BNReLU(high_level_ch),
nn.Conv2d(high_level_ch, num_classes,
kernel_size=1, stride=1, padding=0, bias=True)
)
aux_out = self.aux_head(high_level_features) #soft object regions
#high_level_features为backbone输出的粗略的高层特征结果
将backbone的输出结果经过1*1的卷积后输出b×k×h×w的张量作为软对象区域,其中,k为粗略分类后对象的类别数(eg:若有17个类别,则网络输出为:b×17×h×w)
Second, we estimate the representation for each object region by aggregating the representations of the pixels in the corresponding object region.
第二步: 根据粗略的语义分割结果(soft object regions)和网络最深层输出的像素特征(Pixel Representations)表示计算出 K 组向量,即物体区域表示(Object Region Representations),其中每一个向量对应一个语义类别的特征表示
经评论区@duoheshuiya,@qq_39680835两位大佬指正,此步aux-out出来的c应该是类别数,k为预设的通道数。非常感谢两位的纠正,也向修改之前被我误导各位表示歉意
def forward(self, feats, probs):
batch_size, c, _, _ = probs.size(0), probs.size(1), probs.size(2), \
probs.size(3)
# each class image now a vector
probs = probs.view(batch_size, c, -1)
feats = feats.view(batch_size, feats.size(1), -1)
feats =