解决参考图像分割中的随机性问题:MMNet: Multi-Mask Network for Referring Image Segmentation 论文阅读笔记
写在前面
今天六一儿童节,祝大家节日快乐吖!
这是一篇关于 RIS 参考图像分割的论文,指出 RIS 中的语言和图像存在随机性的问题,观点比较新,虽并未发表在顶会,其方法也不是特别出彩,但论文的立意还是蛮好的。
- 论文地址:https://arxiv.org/abs/2305.14969
- 代码地址:无
- 收录于:加拿大渥太华召开的某个会议
- PS:2023 每周一篇博文,主页更多干货,欢迎关注呀,期待 4 千粉丝有你呦~
一、Abstract
首先指出 Referring image segmentation(RIS)的定义,表明难点在于目标的类别的多样性以及表达式的无约束性。之前的方法主要关注于跨模态的特征对齐而未能解决这种固有的随机性问题。本文提出基于 CLIP 的 Multi-Mask Network(MMNet):首先联合图像和语言,利用注意力机制生成能够表示不同侧重点的多个语言表达式 queries;然后利用这些 queries 生成一系列相应的 masks,并依据重要性程度打分;最后对所有 masks 进行权重求和得到最终的结果。在 RefCOCO、RefCOCO+、G-Ref 上表现很好。
二、引言
RIS 的应用和挑战:图文之间存在明显的数据差异,很难有效对齐;广泛的目标类别和未受限的语言表达式,高度随机性。
早期的方法关注于图像和文本的融合。后来随着注意力机制兴起,大量的方法用来学习跨模态特征。再之后,大规模预训练模型上场。但这些都没能完全解决图文随机性问题。
如上图所示,随机性主要源于两个方面:(1)语句本身的混乱,即一个单词在不同的语境中可以表示不同的含义;(2)强调的侧重点不同。于是需要结合图片来解决随机性问题。
本文通过生成一些列的分割 masks,最后结合这些 masks 来得到最终的结果。
如上图所示,首先基于语言表达式来生成多个 queries。不同于 VLT,本文为每个 query 生成一个相应的 mask,最后通过整合所有的 masks 得到最终的结果。此外,利用 CLLP 模型来提取丰富的视觉-语言知识。本文贡献总结如下:
- 提出 Multi-Mask Network(MMNet) 来生成多个 mask,并利用这些 mask 得到最终的结果,从而解决随机性的问题;
- 充分利用 CLIP 模型提取细粒度的和全局的视觉信息从而提高性能;
- 在 RefCOCO、RefCOCO、G-Ref 上表现很好。
三、相关工作
Referring Image Segmentation
早期的文献首先通过 CNN 和 LSTM 分别提取视觉和语言特征,然后直接拼接进行分类,得到最终的分割结果。一些方法设计出能够同时处理 RIS 和 REC 的网络。之后就是注意力机制的兴起:BRINet、LAVT、VLT、CRIS。然而这些方法主要关注于如何提升特征融合的有效性,但未能解决大量目标以及语言表达式不受限造成的随机性问题。
Vision-Language Pretraining
视觉-语言预训练旨在学习视觉和文本信息的共同表示。CLIP 是其中的里程碑式工作。CRIS 旨在迁移图像级的视觉概念到 RIS 上。然而 CRIS 仅关注于细粒度的视觉表示而忽视了全局视觉信息,恰好这是 CLIP 所擅长的。相比之下,本文提出的方法也利用了 CLIP 模型,但同时关注于细粒度的和全局的视觉信息。
四、方法
首先采用图像和语言表达式作为输入,利用 ResNet/ViT 和一个 Transformer 来提取图像和文本特征以及它们的全局特征。之后全局文本特征和视觉特征进行融合来得到简易的多模态特征。然后在 Multi-Query Generator 中利用全局视觉特征,patch 特征,文本特征来产生多个 queries。接着产生的多个 queries 和多模态特征一起送入到视觉-语言解码器中。解码器的输出和生成的 queries 作为 Multi-Mask Projector 的输入来产生多个 masks。同时 Multi-Query Estimator 利用生成的 queries 来决定每个 mask 的权重。最终使用这些 masks 及相应的权重来进行权重求和,从而得到最终的预测结果。
4.1 图像文本特征提取
文本编码器
给定一个语言表达式 T ∈ R L T\in \mathbb R^L T∈RL,利用一个 Transformer 得到文本特征 F t ∈ R L × C F_t\in \mathbb R^{L\times C} Ft∈RL×C。接下来沿着 CLIP 的方法,使用字节对编码 [SOS] 开始这段序列,用 [EOS] 表示序列的结束。类似于 CRIS,采用Transformer 最高激活层的 [EOS] token 作为整个表达式的全局特征。这一特征之后用全连接层转化为 F t g ∈ R C ′ F_{tg}\in\mathbb R^{C^{\prime}} Ftg∈RC′,其中 C C C、 C ′ C^\prime C′ 为特征维度, L L L 为语言表达式的长度。
图像编码器
给定图像
I
∈
R
H
×
W
×
3
I\in\mathbb R^{H\times W\times3}
I∈RH×W×3,利用 ResNet 来提取第二和第三阶段的特征
X
2
∈
R
H
2
×
W
2
×
C
X_2\in\mathbb R^{H_2\times W_2\times C}
X2∈RH2×W2×C,
X
3
∈
R
H
3
×
W
3
×
C
X_3\in\mathbb R^{H_3\times W_3\times C}
X3∈RH3×W3×C。利用线性全连接层改变其通道数
X
2
∈
R
H
2
×
W
2
×
C
2
X_2\in\mathbb R^{H_2\times W_2\times C_2}
X2∈RH2×W2×C2,
X
3
∈
R
H
3
×
W
3
×
C
3
X_3\in\mathbb R^{H_3\times W_3\times C_3}
X3∈RH3×W3×C3。在第四个阶段,除常规特征
X
4
∈
R
H
4
×
W
4
×
C
4
X_4\in\mathbb R^{H_4\times W_4\times C_4}
X4∈RH4×W4×C4 外,还利用全局平均池化得到全局特征
X
‾
4
∈
R
C
\overline X_4 \in\mathbb R^C
X4∈RC。之后拼接特征
[
X
‾
4
,
X
4
]
[\overline X_4,X_4]
[X4,X4],并将其送入到多头自注意力层:
[
z
‾
,
z
]
=
M
H
S
A
(
[
x
‾
4
,
x
4
]
)
[\overline{\text{z}},\text{z}]=M H S A([\overline{\text{x}}_4,\text{x}_4])
[z,z]=MHSA([x4,x4])之后利用一个全连接层将
z
\text{z}
z、
x
‾
\overline{\text{x}}
x 分别转化为
F
v
4
∈
R
H
4
×
W
4
×
C
4
F_{v4}\in\mathbb{R}^{H_{4}\times W_{4}\times C_{4}}
Fv4∈RH4×W4×C4 和
F
v
g
∈
R
C
4
F_{vg}\in\mathbb{R}^{C_{4}}
Fvg∈RC4。
Fusion Neck
在融合模块中,使用下列式子融合
F
v
4
F_{v4}
Fv4 和
F
v
g
F_{vg}
Fvg,得到
F
m
4
∈
R
H
3
×
W
3
×
C
F_{m4}\in\mathbb R^{H_3\times W_3\times C}
Fm4∈RH3×W3×C:
F
m
4
=
U
p
(
σ
(
F
v
4
W
v
4
)
⋅
σ
(
F
t
g
W
t
g
)
)
F_{m4}=Up\left(\sigma\left(F_{v4}W_{v4}\right)\cdot\sigma\left(F_{tg}W_{tg}\right)\right)
Fm4=Up(σ(Fv4Wv4)⋅σ(FtgWtg))其中
U
p
(
⋅
)
Up(\cdot)
Up(⋅) 表示 2 倍上采样,
⋅
\cdot
⋅ 为逐元素点乘操作,
W
v
4
W_{v4}
Wv4 和
W
t
g
W_{tg}
Wtg 为 全连接层的权重,
σ
\sigma
σ 为 ReLU 激活函数。同样使用相同的步骤得到
F
m
3
F_{m3}
Fm3 和
F
m
2
F_{m2}
Fm2:
F
m
3
=
[
σ
(
F
m
4
W
m
4
)
,
σ
(
F
v
3
W
v
3
)
]
F
m
2
=
[
σ
(
F
m
3
W
m
3
)
,
σ
(
F
v
2
′
W
v
2
)
]
,
F
v
2
′
=
A
v
g
(
F
v
2
)
\begin{aligned} &F_{m_{3}} =\left[\sigma\left(F_{m_4}W_{m_4}\right),\sigma\left(F_{v_3}W_{v_3}\right)\right] \\ &F_{m_{2}} =\left[\sigma\left(F_{m_{3}}W_{m_{3}}\right),\sigma\left(F_{v_{2}}^{\prime}W_{v_{2}}\right)\right],F_{v_{2}}^{\prime}=A v g\left(F_{v_{2}}\right) \end{aligned}
Fm3=[σ(Fm4Wm4),σ(Fv3Wv3)]Fm2=[σ(Fm3Wm3),σ(Fv2′Wv2)],Fv2′=Avg(Fv2)其中
A
v
g
(
⋅
)
A v g(\cdot)
Avg(⋅) 为
2
×
2
2\times2
2×2 的平均池化操作,
[
,
]
[,]
[,] 为拼接操作。接下来,拼接多模态特征
(
F
m
4
,
F
m
3
,
F
m
2
)
(F_{m_{4}},F_{m_{3}},F_{m_{2}})
(Fm4,Fm3,Fm2),并用 1 个
1
×
1
1\times1
1×1 卷积层来聚合:
F
m
=
C
o
n
o
(
[
F
m
2
,
F
m
3
,
F
m
4
]
)
F_m=Cono\left(\left[F_{m_2},F_{m_3},F_{m_4}\right]\right)
Fm=Cono([Fm2,Fm3,Fm4])其中
F
m
∈
R
H
3
×
W
3
×
C
F_m\in\mathbb R^{H_3\times W_3\times C}
Fm∈RH3×W3×C,得其 2D 坐标
F
c
o
o
r
d
∈
R
H
3
×
W
3
×
2
F_{coord}\in\mathbb R^{H_3\times W_3\times 2}
Fcoord∈RH3×W3×2,与
F
m
F_{m}
Fm 拼接并展平得到融合全局文本信息的视觉特征
F
v
t
∈
R
N
×
C
F_{vt}\in\mathbb R^{N\times C}
Fvt∈RN×C,
N
=
H
3
×
W
3
=
H
16
×
W
16
N=H_3\times W_3=\frac{H}{16}\times\frac{W}{16}
N=H3×W3=16H×16W。与 ViT 类似,直接提取类别 token 作为全局视觉特征,然后分别使用三个卷积层得到三个特征,其特征通道维度与
F
v
2
F_{v2}
Fv2、
F
v
3
F_{v3}
Fv3、
F
v
4
F_{v4}
Fv4 相同。
4.2 Multi-Query Generator
Multi-Query Generator 采用多阶段视觉特征 { F v i } i = 2 4 \{F_{vi}\}^4_{i=2} {Fvi}i=24、全局视觉特征 F v g F_{vg} Fvg、文本特征 F t F_t Ft 作为输入,输出一系列的 queries。
Dense Visual Features
通过下列步骤获得稠密的视觉特征:
F
m
4
′
=
U
p
(
σ
(
F
v
4
W
a
4
′
)
)
F
m
3
′
=
[
σ
(
F
m
4
′
W
m
4
′
)
,
σ
(
F
v
3
W
v
3
′
)
]
F
m
2
′
=
[
σ
(
F
m
3
′
W
m
3
′
)
,
σ
(
F
v
2
′
W
v
2
′
)
]
,
F
v
2
′
′
=
A
o
g
(
F
v
2
)
F
m
′
=
C
o
n
v
(
[
F
m
2
′
,
F
m
3
′
,
F
m
4
′
]
)
,
F
v
′
=
C
o
n
v
(
[
F
m
′
,
F
c
o
o
r
d
]
)
\begin{aligned} &F_{m_4}'=Up\left(\sigma\left(F_{v4}W_{a4}'\right)\right) \\ & \\ &F_{m_{3}}^{\prime}=\left[\sigma\left(F_{m_{4}}^{\prime}W_{m_{4}}^{\prime}\right),\sigma\left(F_{v_{3}}W_{v_{3}}^{\prime}\right)\right] \\ &F_{m_{2}}^{\prime}=\left[\sigma\left(F_{m_{3}}^{\prime}W_{m_{3}}^{\prime}\right),\sigma\left(F_{v_{2}}^{\prime}W_{v_{2}}^{\prime}\right)\right],F_{v_{2}}^{\prime\prime}=A o g\left(F_{v_{2}}\right) \\ &F_{m}^{\prime}=C o n v\left(\left[F_{m2}^{\prime},F_{m3}^{\prime},F_{m4}^{\prime}\right]\right),F_{v}^{\prime}=C o n v\left(\left[F_{m}^{\prime},F_{c o o r d}\right]\right) \end{aligned}
Fm4′=Up(σ(Fv4Wa4′))Fm3′=[σ(Fm4′Wm4′),σ(Fv3Wv3′)]Fm2′=[σ(Fm3′Wm3′),σ(Fv2′Wv2′)],Fv2′′=Aog(Fv2)Fm′=Conv([Fm2′,Fm3′,Fm4′]),Fv′=Conv([Fm′,Fcoord])与 Fusion Neck 不同之处在于
F
v
d
F_{vd}
Fvd 并未整合全局文本信息。
接下来应用三个卷积层来减少其特征通道维度到
N
q
N_q
Nq,之后展平宽度和高度:
F
v
d
=
f
l
a
t
t
e
n
(
C
o
n
v
(
F
v
′
)
)
T
F_{vd}=flatten\left(Conv\left(F_v'\right)\right)^T
Fvd=flatten(Conv(Fv′))T于是
F
v
d
∈
R
N
q
×
H
3
W
3
F_{vd}\in\mathbb R^{N_q\times H_3W_3}
Fvd∈RNq×H3W3。
Fused textual features
利用下列等式融合文本特征和全局视觉特征:
F
t
v
=
σ
(
F
t
W
t
)
⋅
σ
(
F
v
g
W
v
g
)
F_{tv}=\sigma\left(F_tW_t\right)\cdot\sigma\left(F_{vg}W_{vg}\right)
Ftv=σ(FtWt)⋅σ(FvgWvg)其中
F
t
v
∈
R
L
×
C
F_{tv}\in\mathbb R^{L\times C}
Ftv∈RL×C,
W
t
W_t
Wt、
W
v
g
W_{vg}
Wvg 为可学习的矩阵。
Multi-Query Generation
首先在
F
v
d
F_{vd}
Fvd 和
F
t
v
F_{tv}
Ftv 上应用线性投影,然后对第
n
n
n 个 query (
n
=
1
,
2
,
…
,
N
q
n=1,2,\dots,N_q
n=1,2,…,Nq),第
n
n
n 个稠密的视觉特征向量
f
v
d
n
∈
R
1
×
(
H
3
W
3
)
f_{vdn}\in\mathbb R^{1\times(H_3W_3)}
fvdn∈R1×(H3W3) 以及第
i
i
i (
i
=
1
,
2
,
…
,
L
i=1,2,\dots,L
i=1,2,…,L) 个单词的文本特征
f
t
v
i
∈
R
1
×
C
f_{tvi}\in\mathbb R^{1\times C}
ftvi∈R1×C,通过计算
f
v
d
n
f_{vdn}
fvdn 和
f
t
v
i
f_{tvi}
ftvi 的点乘投影得到第
i
i
i 个单词与
n
n
n 个 query 的注意力权重:
a
n
i
=
σ
(
f
v
d
n
W
o
d
)
σ
(
f
t
o
t
W
a
)
T
a_{ni}=\sigma\left(f_{vdn}W_{od}\right)\sigma\left(f_{tot}W_{a}\right)^{T}
ani=σ(fvdnWod)σ(ftotWa)T其中
a
n
i
a_{ni}
ani 表示衡量 第
i
i
i 个单词与
n
n
n 个 query 的重要性的标量,
W
o
d
W_{od}
Wod、
W
a
W_{a}
Wa 为可学习的矩阵。之后采用 Softmax 处理
a
n
i
a_{ni}
ani,从而形成注意力图
A
∈
R
N
q
×
L
A\in\mathbb R^{N_q\times L}
A∈RNq×L。对于第
n
n
n 个 query,对应的重要性程度为
A
n
∈
R
1
×
L
A_n\in\mathbb R_{1\times L}
An∈R1×L (
n
=
1
,
2
,
…
,
N
q
n=1,2,\dots,N_q
n=1,2,…,Nq),而
A
n
A_n
An 则用于生成新的 queries:
F
q
n
=
A
n
σ
(
F
t
v
W
t
v
)
F_{qn}=A_n\sigma\left(F_{tv}W_{tv}\right)
Fqn=Anσ(FtvWtv) 其中
W
t
v
W_{tv}
Wtv 为可学习的参数。所有的 queries 组成新的语言矩阵
F
q
∈
R
N
q
×
C
F_q\in\mathbb R^{N_q\times C}
Fq∈RNq×C,作为视觉-语言解码器的输入。
4.3 视觉-语言解码器
将 query 向量
F
q
F_q
Fq 和融合的视觉特征
F
v
t
F_{vt}
Fvt 作为输入,并加上空间位置信息。解码器结构采用标准的 Transformer,流程如下:
F
v
t
′
=
M
H
S
A
(
L
N
(
F
v
t
)
)
+
F
u
t
′
F_{vt}'=MHSA\left(LN\left(F_{vt}\right)\right)+F_{ut}'
Fvt′=MHSA(LN(Fvt))+Fut′其中
F
v
t
′
F_{vt}'
Fvt′ 为视觉特征,
M
H
S
A
(
⋅
)
MHSA(\cdot)
MHSA(⋅) 和
L
N
(
⋅
)
LN(\cdot)
LN(⋅) 表示多头自注意力层和归一化层。
M
H
S
A
(
Q
,
K
,
V
)
=
s
o
f
t
m
a
x
(
Q
K
T
d
k
)
MHSA\left(Q,K,V\right)=softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)
MHSA(Q,K,V)=softmax(dkQKT)其中
Q
∈
R
N
×
d
q
Q\in\mathbb R^{N\times d_q}
Q∈RN×dq 为
F
v
t
′
F_{vt}'
Fvt′ 的投影、
K
∈
R
N
×
d
k
K\in\mathbb R^{N\times d_k}
K∈RN×dk 和
V
∈
R
N
×
d
v
V\in\mathbb R^{N\times d_v}
V∈RN×dv 是
F
q
F_q
Fq 的投影。后续处理:
F
s
′
=
M
H
C
A
(
L
N
(
F
v
t
′
)
,
F
q
)
+
F
v
t
′
F
s
=
M
L
P
(
L
N
(
F
S
′
)
)
+
F
S
′
\begin{aligned} &F_{s}^{\prime} =MHCA\left(LN\left(F'_{vt}\right),F_q\right)+F'_{vt} \\ &F_{s} =MLP\left(LN\left(F_S'\right)\right)+F_S' \end{aligned}
Fs′=MHCA(LN(Fvt′),Fq)+Fvt′Fs=MLP(LN(FS′))+FS′其中
M
H
C
A
(
⋅
)
MHCA(\cdot)
MHCA(⋅) 为多头交叉注意力层,
F
S
′
F_S'
FS′ 为中间特征,多模态特征
F
s
F_s
Fs 用于产生最后的分割 mask。
4.4 Mask 解码器
Multi-Mask Projector
Multi-Mask Projector 以多模态特征
F
s
F_s
Fs 和 query 向量
F
q
F_q
Fq 为输入。从
F
q
F_q
Fq 中提取
F
q
n
F_{qn}
Fqn,在
F
s
F_s
Fs 的作用下生成 mask。接下来采用动态卷积生成
F
q
n
F_{qn}
Fqn:
F
p
=
U
p
(
C
o
n
v
(
U
p
(
F
s
)
)
)
F
p
n
=
σ
(
W
p
F
q
n
)
\begin{aligned} &F_{p}=U p(C o n v(U p(F_{s}))) \\ &F_{p n}=\sigma(W_{p}F_{q n}) \end{aligned}
Fp=Up(Conv(Up(Fs)))Fpn=σ(WpFqn)其中
F
s
F_s
Fs 上采样和卷积到
F
p
∈
R
4
H
3
×
4
W
3
×
C
p
F_p\in\mathbb R^{4H_3\times4W_3\times C_p}
Fp∈R4H3×4W3×Cp,
C
p
=
C
2
C_p=\frac{C}{2}
Cp=2C。之后利用线性层将
F
q
n
F_{qn}
Fqn 变为
F
p
n
∈
R
9
C
p
+
1
F_{pn}\in\mathbb R^{9C_p+1}
Fpn∈R9Cp+1。采用向量
F
p
n
F_{pn}
Fpn 中第一个
9
C
p
9C_p
9Cp 值作为
3
×
3
3\times3
3×3 卷积核的参数,通道数量为
C
p
C_p
Cp,
F
p
n
F_{pn}
Fpn 的最后一个值为偏置。之后利用卷积从第
n
n
n 个query
F
q
n
F_{qn}
Fqn 中得到 mask,表示为
m
a
s
k
n
∈
R
4
H
3
×
4
W
3
×
1
mask_n\in\mathbb R^{4H_3\times4W_3\times1}
maskn∈R4H3×4W3×1。
Multi-Query Estimator
Multi-Query Estimator 采用 query 向量
F
q
F_q
Fq 作为输入,输出
N
q
N_q
Nq 得分,每个得分表明 query
F
q
n
F_{qn}
Fqn 拟合预测上下文的程度,并控制相应的
m
a
s
k
n
mask_n
maskn。用公式表示如下:
S
q
=
S
o
f
t
m
a
x
(
W
s
(
M
H
S
A
(
F
q
)
)
)
S_q=Softmax(W_s(MHSA(F_q)))
Sq=Softmax(Ws(MHSA(Fq)))其中
S
q
∈
R
N
q
×
1
S_q\in\mathbb R^{N_q\times1}
Sq∈RNq×1。于是最终的预测结果为 Multi-Mask Generator 输出的 mask 与 Multi-Query Estimator 输出的得分进行权重求和:
y
=
∑
n
=
1
N
q
S
q
n
m
a
s
k
n
y=\sum_{n=1}^{N_q}S_{qn}mask_n
y=n=1∑NqSqnmaskn其中
S
q
n
S_{qn}
Sqn 为
S
q
S_q
Sq 的第
n
n
n 个标量,
y
y
y 表示最终的预测 mask。模型采用 cross-entropy 损失进行优化。
五、实验
5.1 实施细节
实验设置
ResNet-101、ViT 作为图像编码器,输入图像尺寸
480
×
480
480\times480
480×480,RefCOCO 和
RefCOCO+ 设置句子长度为 17,G-Ref 22。每个 Transformer 块有 8 个头,隐藏层维度 512,前向传播维度 2048。100 epochs,Adam 优化器,初始学习率
l
r
=
1
e
−
5
lr=1e-5
lr=1e−5,多项式衰减策略。batch 64,8 块 3090。
指标:IoU、Precision@X。
5.2 数据集
RefCOCO、RefCOCO+、G-Ref。
5.3 与其他方法的比较
5.4 消融实验
在 RefCOCO+ 的 testA 分布上进行实验,ResNet-50 为视觉 backbone,epoch 50。
Query 的数量
Multi-Mask Projector
Multi-Query Estimator
Global visual feature
同上表 4。
六、可视化
七、结论
本文提出一种端到端的基于 CLIP 模型的 Multi-Mask Network (MMNet) 方法,有效解决了目标类别差异以及未受限语言造成的随机性问题。在 RefCOCO、RefCOCO+、G-Ref 数据集上表现很好。
写在后面
通篇读下来,这篇论文还是略显啰嗦一点了,感觉有点水字数的嫌疑(想灌灌水的可以参考下哈~)。另外,这个随机性真的解决了吗?保留怀疑态度,就是创新点和解决的问题有点对不上的感觉。