前言
hrnet_ocr
是 Semantic Segmentation on Cityscapes test
中目前排名第一的语义分割模型,其将高分辨网络hrnet
和 OCRNet
方法相结合,本文主要介绍OCRNet
方法。
OCRNet
提出背景:使用一般性的ASPP
方法如图(a),其中红点是关注的点,蓝点和黄点是采样出来的周围点,若将其作为红点的上下文,背景和物体没有区分开来,这样的上下文信息对红点像素分类帮助有限。为改善此情况,提出OCRNet
方法如图(b),其可让上下文信息关注在物体上,从而为红点提供更有用的信息。
论文:https://arxiv.org/pdf/1909.11065.pdf
源码:https://git.io/openseg and https://git.io/HRNet.OCR
OCRNet 网络
OCRNet
方法总体思路:coarse-to-fine
的语义分割过程,首先用一般的语义分割模型得到一个粗略的分割结果,同时从backbone
还可获得每个像素的特征,根据每个像素的语义信息和特征,可以得到每个类别的特征;随后可计算像素特征与各个类别特征的相似度,根据该相似度可得到每个像素点属于各类别的可能性,进一步把每个区域的表征进行加权,会得到当前像素增强的特征表示(object-contextual representation)
,整体流程如下:
Step1:提取类别区域特征
目标:根据 像素语义信息 和 像素特征 得到每个 类别区域特征。其中像素语义信息是常规的语义分割结果,像素特征就是backbone
提取得到的特征图,具体做法如下:
(1)像素语义(20×100×100)展开成二维(20×10000),其每一行表示每个像素点(10000个像素点)属于某类物体(总共20个类)的概率。
(2)像素特征(512×100×100)展开成二维(512×10000),其每一列表示每个像素点(10000个像素点)在某一维特征(512维)。
(3)像素语义的每行乘以像素特征的每列再相加,得到类别区域特征,其每一行表示某个类(20类)的512维特征。
计算代码如下:
def get_proxy(feats,probs):
batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
# 1, 20, 100, 100
probs = probs.view(batch_size, c, -1)
# (1, 20, 10000)
feats = feats.view(batch_size, feats.size(1), -1)
# (1, 512, 10000)
feats = feats.permute(0, 2, 1) # batch x hw x c
# (1, 10000, 512)
probs = F.softmax(self.scale * probs, dim=2)# batch x k x hw
# (1, 20, 10000)
proxy = torch.matmul(probs, feats).permute(0, 2, 1).unsqueeze(3)# batch x k x c
# (1, 512, 20, 1)
return proxy
if __name__ == "__main__":
feats = torch.randn((1, 512, 100, 100))
probs = torch.randn((1, 20, 100, 100))
proxy=get_proxy(feats,probs)
Step2:像素区域相似度
对像素特征 feats
和 step1
得到类别区域特征 proxy
,使用 self-attention
得到像素与区域的相似度,即依赖关系。
self-attention
中
Q
Q
Q,
K
K
K,
V
V
V 计算如下:
{
Q
=
f
_
p
i
x
e
l
(
f
e
a
t
s
)
K
=
f
_
o
b
j
e
c
t
(
p
r
o
x
y
)
V
=
f
_
d
o
w
n
(
p
r
o
x
y
)
\begin{cases} Q=f\_pixel(feats)\\ K=f\_object(proxy)\\ V=f\_down(proxy)\\ \end{cases} \\
⎩⎪⎨⎪⎧Q=f_pixel(feats)K=f_object(proxy)V=f_down(proxy)
f
_
p
i
x
e
l
f\_pixel
f_pixel 和
f
_
o
b
j
e
c
t
f\_object
f_object 代码如下:
f_pixel = nn.Sequential(
nn.Conv2d(in_ch=in_ch, out_ch=key_ch,kernel_size=1, stride=1, padding=0),
ModuleHelper.BNReLU(key_ch, bn_type=bn_type),
nn.Conv2d(in_ch=key_ch, out_ch=key_ch,kernel_size=1, stride=1, padding=0),
ModuleHelper.BNReLU(key_ch, bn_type=bn_type),
)
f_object = nn.Sequential(
nn.Conv2d(in_ch=in_ch, out_ch=key_ch,kernel_size=1, stride=1, padding=0),
ModuleHelper.BNReLU(key_ch, bn_type=bn_type),
nn.Conv2d(in_ch=key_ch, out_ch=key_ch,kernel_size=1, stride=1, padding=0),
ModuleHelper.BNReLU(key_ch, bn_type=bn_type),
)
根据
Q
Q
Q 和
K
K
K 得到像素与区域的依赖关系:
s
i
m
m
a
p
=
s
o
f
t
m
a
x
(
Q
K
T
d
k
)
simmap=softmax(\frac {QK^T}{\sqrt{d_k}})
simmap=softmax(dkQKT)
计算代码如下:
def get_sim_map(feats, proxy):
x=feats
batch_size, h, w = x.size(0), x.size(2), x.size(3)
# 1, 100, 100
## qk
query = f_pixel(x).view(batch_size, self.key_channels, -1)
# (1, 256, 10000)
query = query.permute(0, 2, 1)
# (1, 256, 10000)
key = f_object(proxy).view(batch_size, self.key_channels, -1)
# (1, 256, 20)
value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
# (1, 256, 20)
value = value.permute(0, 2, 1)
# (1, 20, 256)
## sim
sim_map = torch.matmul(query, key)
# (1, 10000, 20)
sim_map = (self.key_channels**-.5) * sim_map
# (1, 10000, 20)
sim_map = F.softmax(sim_map, dim=-1)
# (1, 10000, 20)
return sim_map
if __name__ == "__main__":
feats = torch.randn((1, 512, 100, 100))
proxy=get_proxy(feats,probs)
sim_map=get_sim_map(feats,proxy)
Step3:获得上下文表示
由step2
计算得到simmap
,其乘以V
则可context
,将context
和像素特征进行拼接,再做通道调整得到最终的上下文表示,计算公式如下:
c
o
n
t
e
x
t
=
s
i
m
m
a
p
×
V
c
o
n
t
e
x
t
=
c
o
n
v
_
b
_
d
r
o
p
o
u
t
(
t
o
r
c
h
.
c
a
t
(
[
c
o
n
t
e
x
t
,
f
e
a
t
s
]
,
1
)
)
context=simmap ×V \\ context=conv\_b\_dropout(torch.cat([context, feats], 1))
context=simmap×Vcontext=conv_b_dropout(torch.cat([context,feats],1))
计算代码如下:
def get_context(feats,proxy,sim_map):
context = torch.matmul(sim_map, value) # hw x k x k x c
# (1, 10000, 256)
context = context.permute(0, 2, 1).contiguous()
# (1, 10000, 256)
context = context.view(batch_size, self.key_channels, *x.size()[2:])
# (1, 256, 100, 100)
context = f_up(context)
# (1, 512, 100, 100)
output = self.conv_bn_dropout(torch.cat([context, feats], 1))
# (1, 512, 100, 100)
return output
if __name__ == "__main__":
feats = torch.randn((1, 512, 100, 100))
proxy=get_proxy(feats,probs)
sim_map=get_sim_map(feats,proxy)
output=get_context(proxy,sim_map)
参考
Object-Contextual Representations for Semantic Segmentation