1. Title
Learning Position and Target Consistency for Memory-based Video Object Segmentation
2. Summary
本文主要解决的问题领域是半监督VOS,而在半监督VOS领域中,基于匹配的方法目前取得了较好的结果,但是这种方法也存在着一些缺陷:没有考虑帧间的时序顺序信息、缺乏对目标整体信息的利用等。
基于以上观察,本文首先是搭建了一个基于匹配的方法常用的Memory-Matching的框架,利用Matching方法的高召回率的特性,完成大部分像素的匹配和识别。
在此基础上,本文额外引入了两个模块用于获取相邻帧之间的位置信息以及待分割目标的整体特性。
模型整体根据交互关系可以分为以下三个部分:
- 当前帧与之前所有帧的交互
这一步主要是由Global Retrieval Module(GRM)完成。其本质上就是一个STM网络,计算当前帧像素与之前帧在时域和空域上像素的相似度,然后召回相应的像素。 - 当前帧与前一帧的交互
这一步主要是由Position Guidance Module(PGM)完成。考虑到相邻帧之间的位置信息较为相似,因此,通过对相邻帧的对比,可以对当前预测结果的位置进行一定的约束。 - 当前帧与第一帧的交互
这一步主要是通过Object Relation Module(ORM)完成。第一帧的mask是最为可靠和准确的,因此,可以利用其得到一个Object-Level的特征,在整个过程中对预测的mask进行约束,从而缓解Memory中的累计误差问题。
3. Problem Statement
本文研究的问题是半监督视频目标分割(Semi-supervised Video Object Segmentation)问题。
该问题描述如下:
对于一个待分割的视频序列,给出第一帧的Ground Truth Mask,算法需要根据第一帧标注的对象,将其在后续视频序列中分割出来。
在视频序列中,由于待分割目标的外观可能会由于运动、摄像视角改变、遮挡等问题出现较大的变化,同时视频中可能还存在与目标外观相似的不同实例需要进行区分,这些是VOS问题的难点。
目前基于memory的方法在VOS领域被广泛应用,其主要基于空间和时域上的像素间的匹配来完成目标的检测和分割。
这种基于memory的方法存在一些问题:
- 没有考虑帧间的时序顺序
- 缺乏对目标整体信息的使用
上述问题会导致对于不同帧中具有相似像素特性的非目标物体的识别会存在问题,同一类别的不同实例都可能会被匹配并被识别为待检测目标。
4. Method(s)
基于以上观察,本文提出了一个Learn position and target Consistency framework for Memory-based video object segmentation(LCM)的框架,其整体基于memory机制用于进行全局像素召回,与此同时,其会学习一种位置一致性约束,从而获得更为可信的分割结果,同时为了提升模型对于误差漂移现象的鲁棒性,还引入了一个整体一致性约束。
4.1 Overview
LCM使用一个Encoder-Decoder结构完成分割。
对于current frame,会得到三个Embeddings:
K
e
y
−
G
,
K
e
y
−
L
,
V
a
l
u
e
Key{-}G,Key{-}L,Value
Key−G,Key−L,Value,这三个Embeddings会被充分应用于三个模块:
- Global Retrieval Module(GRM)
GRM整体设计与STM一致,其用于计算current frame与memory pool的像素级的特征相关性。Memory Pool中存储的是由Memory Encoder产生的之前帧的 K e y − G , V a l u e Key{-}G,Value Key−G,Value。 - Position Guidance Module(PGM)
由于前一帧与当前帧之间存在相似的位置信息,因此,PGM将用于获取相邻帧之间的位置关系,用于对召回的像素进行位置上的约束。 - Object Relation Module(ORM)
更进一步,为了将Object-Level的信息融入到Pixel-Level Matching过程中,从而尽可能抑制Memory Pool中的累计错误,ORM用于获取Object-Level的特征。
第一帧的信息由于是最为可靠的,因此,其在推理过程中会全程得到保留。
4.2 Global Retrieval Module
GRM整体与STM类似,之前的frames及其预测的masks都会通过Memory Encoder进行编码,并存入Memory Pool中,而当前frame则通过Query Encoder进行编码。所有Encoder均为ResNet-50。
对于第
t
t
t帧,其输出的特征图分别为
r
M
∈
R
H
×
W
×
C
r^{M} \in \mathbb{R}^{H \times W \times C}
rM∈RH×W×C 和
r
Q
∈
R
H
×
W
×
C
r^{Q} \in \mathbb{R}^{H \times W \times C}
rQ∈RH×W×C。
对于之前的帧,其输出的特征图
r
M
r^{M}
rM会经过两个不同的3*3卷积分别进行embed,得到 Memory Global Key
k
M
∈
R
H
×
W
×
C
/
8
k^{M} \in \mathbb{R}^{H \times W \times C / 8}
kM∈RH×W×C/8和 Memory Value
v
M
∈
R
H
×
W
×
C
/
2
v^{M} \in \mathbb{R}^{H \times W \times C / 2}
vM∈RH×W×C/2,之前的
T
T
T帧的Memory Global Key和Memory Value会被存储入Memory Pool中,并在时间维度上进行拼接,得到
k
p
M
∈
R
T
×
H
×
W
×
C
/
8
k_{p}^{M} \in \mathbb{R}^{T \times H \times W \times C / 8}
kpM∈RT×H×W×C/8 和
v
p
M
∈
R
T
×
H
×
W
×
C
/
2
v_{p}^{M} \in \mathbb{R}^{T \times H \times W \times C / 2}
vpM∈RT×H×W×C/2。
对于Query Image也就是current frame,Query Global Key
k
Q
∈
R
H
×
W
×
C
/
8
k^{Q} \in \mathbb{R}^{H \times W \times C / 8}
kQ∈RH×W×C/8 会同样使用卷积从
r
Q
r^Q
rQ中得到。
GRM会基于Query Frame与Memory Pool的Global Key之间的相似性,进行匹配像素的召回:
s
(
i
,
j
)
=
exp
(
k
p
M
(
i
)
⊙
k
Q
(
j
)
⊤
)
∑
i
exp
(
k
p
M
(
i
)
⊙
k
Q
(
j
)
⊤
)
s(i, j)=\frac{\exp \left(k_{p}^{M}(i) \odot k^{Q}(j)^{\top}\right)}{\sum_{i} \exp \left(k_{p}^{M}(i) \odot k^{Q}(j)^{\top}\right)}
s(i,j)=∑iexp(kpM(i)⊙kQ(j)⊤)exp(kpM(i)⊙kQ(j)⊤)
其中,
i
i
i和
j
j
j为Memory Pool和Query Pixel Feature的索引值,
⊙
\odot
⊙表示矩阵内积,函数
s
s
s表示Softmax操作。
最后,召回的value feature通过以下方式计算得到:
y
G
R
M
(
j
)
=
∑
i
s
(
i
,
j
)
⊙
v
p
M
(
i
)
y^{G R M}(j)=\sum_{i} s(i, j) \odot v_{p}^{M}(i)
yGRM(j)=i∑s(i,j)⊙vpM(i)
GRM的核心贡献是其具有较高的召回率,但是其也存在着一些缺陷,Correlation Map的计算过程中并没有考虑位置一致性,所有的特征都是被同等对待。因此网络知道去寻找相似的区域,但是却不能对目标物体进行追踪。
4.3 Position Guidance Module
PGM通过对相邻帧进行编码,学习Position Consistency。
Encoder除了会输出Global Key,同时还会基于Res4 Feature Map产生一个Local Key,用于学习Local Position。
具体而言,另一个3*3卷积层将用于从Query Embedding和前一帧Memory Embedding中产生 Query Local Key k L Q ∈ R H × W × C / 8 k_{L}^{Q} \in \mathbb{R}^{H \times W \times C / 8} kLQ∈RH×W×C/8 和 Memory Local Key k I M ∈ R H × W × C / 8 k_{I}^{M} \in \mathbb{R}^{H \times W \times C / 8} kIM∈RH×W×C/8。
由于相似度的计算对位置具有不变性,因此,在PGM中额外引入了Positional Encoding用于引入位置信息。本文的Position Encoding采用的是不同频率的Sine和Cosine函数作为固定的位置编码。具体过程为:
p
M
(
i
)
=
f
n
(
k
L
M
(
i
)
+
pos
(
i
)
)
p
Q
(
j
)
=
f
n
(
k
L
Q
(
j
)
+
pos
(
i
)
)
\begin{aligned} p^{M}(i) &=f_{n}\left(k_{L}^{M}(i)+\operatorname{pos}(i)\right) \\ p^{Q}(j) &=f_{n}\left(k_{L}^{Q}(j)+\operatorname{pos}(i)\right) \end{aligned}
pM(i)pQ(j)=fn(kLM(i)+pos(i))=fn(kLQ(j)+pos(i))
随后通过矩阵乘法和Softmax函数用于得到一个针对前一帧各个位置的Response Distribution Map,除此之外,再引入前一帧的预测mask,进一步对非目标区域进行抑制:
S
(
i
,
j
)
=
exp
(
p
Q
(
j
)
⊙
p
M
(
i
)
⊤
)
∑
j
exp
(
p
Q
(
j
)
⊙
p
M
(
i
)
⊤
)
∗
g
(
M
t
−
1
)
S(i, j)=\frac{\exp \left(p^{Q}(j) \odot p^{M}(i)^{\top}\right)}{\sum_{j} \exp \left(p^{Q}(j) \odot p^{M}(i)^{\top}\right)} * g\left(M_{t-1}\right)
S(i,j)=∑jexp(pQ(j)⊙pM(i)⊤)exp(pQ(j)⊙pM(i)⊤)∗g(Mt−1)
g
(
x
)
=
exp
(
x
)
e
g(x)=\frac{\exp (x)}{e}
g(x)=eexp(x)用于避免由于预测错误背景产生的接近于0的响应对结果的过大干扰。
接着通过选取Memory Dimension的top-K个值,并取其平均值,得到一个Position Map,大小为
H
×
W
H \times W
H×W,并作为最终的Spatial Attention Map,用于对Query Value进行加权:
y
P
G
M
(
j
)
=
∑
i
top
K
{
S
(
i
,
j
)
}
K
∗
v
Q
(
j
)
y^{P G M}(j)=\frac{\sum_{i} \operatorname{top} K\{S(i, j)\}}{K} * v^{Q}(j)
yPGM(j)=K∑itopK{S(i,j)}∗vQ(j)
4.4 Object Relation Module
基于匹配的方法是一种自底向上的方法,缺乏上下文语义信息,Memory Pool中的累积错误可能会影响后续的匹配和位置关系的获取。
由于第一帧的标注是准确、可靠的,因此,本文提出了一个Object Relation Module,用于从第一帧中获取Object-Level的信息,以此作为先验信息,使得整个视频的推理过程中尽可能保持目标的一致性。
根据Ground Truth Mask,提取第一帧的value
v
F
v^F
vF 中的的前景特征得到一个value set
F
{
f
i
}
F\{f_i\}
F{fi},
i
i
i表示属于某个前景目标的位置。
ORM使用一个Cross Relation机制用于将Object-Level的特征融合入Query Value中:
F
Q
{
(
f
i
)
}
=
1
d
∑
j
f
(
F
{
f
i
}
,
v
Q
(
j
)
)
∗
g
(
v
Q
(
j
)
)
v
F
Q
(
j
)
=
1
d
∑
j
f
(
v
Q
(
j
)
,
F
{
f
i
}
)
∗
g
(
F
{
f
i
}
)
\begin{array}{c} F_{Q}\left\{\left(f_{i}\right)\right\}=\frac{1}{d} \sum_{j} f\left(F\left\{f_{i}\right\}, v^{Q}(j)\right) * g\left(v^{Q}(j)\right) \\ v_{F}^{Q}(j)=\frac{1}{d} \sum_{j} f\left(v^{Q}(j), F\left\{f_{i}\right\}\right) * g\left(F\left\{f_{i}\right\}\right) \end{array}
FQ{(fi)}=d1∑jf(F{fi},vQ(j))∗g(vQ(j))vFQ(j)=d1∑jf(vQ(j),F{fi})∗g(F{fi})
其中,
d
=
H
∗
W
d=H*W
d=H∗W是归一化系数,
g
g
g是一个1*1卷积,
f
f
f是两个向量间的点乘操作。
最后,增强后的第一帧的特征将会采用类似于SENet的类似的方法,对原始特征进行通道域上的增强:
v
Q
(
j
)
=
v
Q
(
j
)
+
v
F
Q
(
j
)
F
{
f
i
}
=
F
{
f
i
}
+
F
Q
{
(
f
i
)
}
y
O
R
M
(
j
)
=
v
Q
(
j
)
∗
G
A
P
(
F
{
f
i
}
)
\begin{array}{c} v^{Q}(j)=v^{Q}(j)+v_{F}^{Q}(j) \\ F\left\{f_{i}\right\}=F\left\{f_{i}\right\}+F_{Q}\left\{\left(f_{i}\right)\right\} \\ y^{O R M}(j)=v^{Q}(j) * G A P\left(F\left\{f_{i}\right\}\right) \end{array}
vQ(j)=vQ(j)+vFQ(j)F{fi}=F{fi}+FQ{(fi)}yORM(j)=vQ(j)∗GAP(F{fi})
4.5 Training Strategy
(1)Pre-training on Static Images
为了得到更好的参数初始化,本文首先在用静态图片和视频人工合成的数据集上进行预训练,每张图片通过使用仿射变换得到额外两张假图。
(2)Main-Training on Real Videos without Temporal Limit
不同于之前的模型设定,本文在训练过程中并没有限定采样帧之间的间隔,所有的帧均从随机打乱后的视频帧中采样得到,只有在三帧图片中均出现的前景目标才会作为实际前景目标。
(3)Fine-tuning on Real Videos as Sequence
在推理阶段,mask结果是逐帧逐帧推理得到的,为了降低训练和测试之间的gap,本文在连续的视频帧上对网络进行了进一步的微调。
5. Evaluation
5.1 对比实验
本文在DAVIS、YouTubeVOS等数据集上进行相应的实验。