在图像任务中,因果模型针对的最后任务可以是多种多样的,然而其生效的核心思想却都大同小异:在模型对图像进行高纬特征提取时,更加关注我们认为的与结果相对应的因果信息,而尽量忽视背景等与目标的虚假相关性。因果推断基本理论告诉我们,要想实现 P ( Y ∣ X ) = P ( Y ∣ d o ( X ) ) P(Y|X)=P(Y|do(X)) P(Y∣X)=P(Y∣do(X)),一个经典的方法就是对 Y Y Y和 X X X的虚假相关变量 Z Z Z进行干涉(将其条件限定),从而切断 X X X与 Z Z Z的backdoor。如何从输入图像中将 X X X和 Z Z Z进行分离(或者是二者的高维特征),并且在保存 X X X的情况下对 Z Z Z进行人为干涉,就成了每个因果模型的核心问题。
下面给出几篇文章中的因果推断实现方法:
一、Proactive Pseudo-Intervention: Causally Informed Contrastive Learning For Interpretable Vision Models
本文的基本思想是,分类器在进行分类时,应该关注于图像的因果信息
X
X
X,而去忽略图像的无关背景信息
Z
Z
Z。因此本文先是利用分类器构建了一个WBP过程(图中给出了以全连接和ReLU激活的分类器的WBP过程,其他网络和激活见原文的Table 1),其作用是通过对每个图像点正向传播时对分类结果的影响程度,来反向传播回图像上的对应位置,从而得到一张每个像素点的权重图。之后通过将原图中WBP权重高的地方进行人为的mask,来获得去掉了因果信息
X
X
X的图像。作者认为当一个图像完全去掉了因果信息,分类器就不能将其正确分类,而此时WBP也就不会显示高权重的区域。此外为了避免分类器对图像是否mask产生的本文的对比学习的思想,作者将mask掉WBP区域的图像称之为完整图像的negative样本,mask掉其他区域的称之为positive样本,分类器在训练时应该将完整图像和positive样本正确分类,而将negative样本错误分类。通过这种方法,分类器的WBP过程产生的saliency map将越来越集中到因果信息上。
本文的优化目标:
1、
其中
x
i
∗
x^*_i
xi∗是
x
i
x_i
xi图像对应的negative样本,
⊢
y
\vdash y
⊢y代表label
y
y
y的相反值,e.g. label 是二分类,
y
=
0
y=0
y=0,那么
⊢
y
=
1
\vdash y=1
⊢y=1,如果是多分类,使用one vs. others cross entropy loss。
f
θ
f_{\theta}
fθ是分类网络。
2、
其中
x
i
′
x'_i
xi′是
x
i
x_i
xi图像对应的positive样本。
3、
其中
s
m
s_m
sm代表图像
x
x
x对应的label =
m
m
m的WBP图。这个损失项的目的是对生成的WBP图中权重显示区域限制,从而只显示其中最重要的部分。
之后还有通过 x x x和 y y y对 f θ f_{\theta} fθ进行训练的普通分类损失,不细说。
于是本文的整体目标损失为:
二、Improving Weakly-supervised Object Localization via Causal Intervention
本文将图像在分类问题中受到背景信息
C
C
C纠缠从而使输入和目标
Y
Y
Y之间产生虚假相关性的问题用下列因果图表示:
其中
V
V
V代表因果信息
X
X
X和背景信息
C
C
C的共同结果,可以表示为
V
=
f
(
X
,
C
)
V=f(X,C)
V=f(X,C),其中
C
=
(
c
1
,
c
2
,
.
.
.
,
c
n
)
C=(c_1,c_2,...,c_n)
C=(c1,c2,...,cn)代表不同类图像的背景信息。本文的目的是获得
P
(
Y
∣
d
o
(
X
)
)
P(Y|do(X))
P(Y∣do(X)),用公式表示:
当我们把所有类的背景信息都作用到了
X
X
X上时,此时就相当于X不会对单独的某一类的背景信息产生虚假的相关性,那么我们就实现了上图中的(b)。我们依据Normalized Weighted Geometric Mean对上式进一步变形:
这使得我们可以将n步优化转换为1步优化。之后,如果我们进一步假设所有类的分布近乎相同,那么我们可以进一步将
P
(
c
i
)
P(c_i)
P(ci)替换为
1
/
n
1/n
1/n,则上式可以进一步化为:
其中
⊕
\oplus
⊕代表投影,(也就是后面我们要讲到的Casual context pool的操作)。因此,通过上面一系列分析,本文将问题转化为了求后面求和项的问题。而后面的求和项,就是本文的重点,上面流程图中的Casual context pool。
本文的流程如下:首先通过一个VGG网络将RGB图转换为位置相关的高维特征图
X
X
X(position-aware feature maps),其次将其进入到一个CAM模块中,针对目标,使用分类器产生一系列CAM图M和每个类的分类概率
S
=
{
s
1
,
s
2
,
.
.
.
,
s
n
}
S=\{s_1,s_2,...,s_n\}
S={s1,s2,...,sn},权重最高的类的CAM图经过BN后将会被送入到Causal Context pool中。Causal Context pool是一个在整个训练过程中不断的收集个各类的背景信息的字典,我们可以用
Q
∈
R
n
×
h
×
w
Q \in R^{n×h×w}
Q∈Rn×h×w来表示。在每次训练的过程中,Q将输入的CAM图的背景信息储存到对应的类的通道中,然后将之前储存的对应的类的所有背景信息映射到前面VGG所输出的高纬特征上,从而完成一个对应背景信息储存+对应类信息强化的过程。其信息储存过程可以用
Q
π
=
B
N
(
Q
π
+
λ
×
B
N
(
M
π
)
)
)
Q_π = BN(Q_π + λ × BN(M_π)))
Qπ=BN(Qπ+λ×BN(Mπ)))来表示,其中
π
\pi
π代表更新的类别。由此我们得到了被
Q
Q
Q强化后的高纬特征图
X
e
=
X
+
X
⊙
C
o
n
v
1
×
1
(
Q
π
)
X^e = X + X \odot Conv_{1×1} (Q_π )
Xe=X+X⊙Conv1×1(Qπ)。在这个强化版的特征图中我们在对背景信息进行了人为干扰的同时也强化了目标信息。然后我们再将强化后的特征图再次输入到CAM模块中,获得强化版的CAM图。这里需要说明一点,前后两次CAM模块的参数权重是共享的。于是整个模型通过维护一个Q,实现了上面分析式中的加和项,而没有引入其他训练参数。
最终我们的优化目标是:
其中
s
∗
s^*
s∗是真实label,
s
s
s是第一次CAM过程产生的label,
s
e
s^e
se是第二次CAM过程产生的label。
三、Style Normalization and Restitution for Domain Generalization and Adaptation
本文针对的问题是,在涉及到不同风格的图像的处理问题上,Instance Norm方法可以极大的去掉每张图像自身的域信息,或者理解为将
Z
Z
Z进行统一化,保留
X
X
X,从而方便后面的任务对图像object进行处理。然而这种方法在去掉图像域信息的同时,不可避免的会去掉一些object的信息,这就会对后面的任务产生不好的影响。因此这篇文章提出我们需要在不将图像域信息引入的情况下对object信息进行加强,或者说是将前面步骤中去掉的部分object信息再补充回来。
本文的基本思想是:我们从风格归一化过程中去掉的信息中提取部分信息添加到我们的风格归一化特征里,如果我们补充的信息是与任务有关的因果信息,那么将使得组合后的特征有着更好的可分性;而如果是与因果无关的风格信息,那么这些信息可以看成是干扰,也就是组合后的信息会变得更难分辨。
具体的方法见上图(b)部分。本文将现有网络流程中提取的图像特征
F
F
F进行IN,得到风格归一化的特征
F
~
=
I
N
(
F
)
\tilde{F}=IN(F)
F~=IN(F),之后我们获得二者之间的差异
R
=
F
−
F
~
R=F-\tilde{F}
R=F−F~(理想状态下这个差异仅仅代表每张图像自身的风格信息)。之后我们将
R
R
R进行通道attention,获得一个attention向量代表每个通道的权重,我们用
a
a
a表示,我们希望包含object信息的通道可以有着更高的权重,那么也就是在将权重
a
a
a添加到
R
R
R上后,得到的将是前面被IN去掉的部分object信息,用
R
+
R^+
R+表示。我们将
R
+
R^+
R+添加到
F
~
\tilde{F}
F~上后的结果
F
~
+
=
R
+
+
F
~
\tilde{F}^+=R^++\tilde{F}
F~+=R++F~相比于
F
~
\tilde{F}
F~会被更好的分类(有明确类信息)。同时为了进行相互对比,我们将权重
1
−
a
1-a
1−a添加到
R
R
R上后得到
F
~
−
=
R
−
+
F
~
\tilde{F}^-=R^-+\tilde{F}
F~−=R−+F~,相比于
F
~
\tilde{F}
F~将变得难以分类(无明确类信息)。此外在上图(c)部分涉及到如何将得到的
F
~
+
\tilde{F}^+
F~+、
F
~
−
\tilde{F}^-
F~−和
F
~
\tilde{F}
F~进行变换后输入到分类器(我们用
ϕ
\phi
ϕ表示)训练的问题。本文对于不同的下游任务所采用的变换方法是不同的:针对分类任务,直接使用spatial average pool将上面三个高维特征变成
1
×
1
×
C
1 \times 1 \times C
1×1×C的向量;针对分割任务,会对每一个像素对应的通道看成一个向量;对于检测任务,则是用检测框遍历特征的长宽维度,将得到的检测框大小的范围进行spatial average pool。最终我们得到输入到分类器
f
+
=
p
o
o
l
(
F
+
)
∈
R
c
f
−
=
p
o
o
l
(
F
−
)
∈
R
c
f
=
p
o
o
l
(
F
)
∈
R
c
f^+ = pool(F^+) \in R^c\\ f^-=pool(F^-) \in R^c\\ f=pool(F) \in R^c
f+=pool(F+)∈Rcf−=pool(F−)∈Rcf=pool(F)∈Rc
本文优化目标:
首先
S
o
f
t
p
l
u
s
(
⋅
)
=
l
n
(
1
+
e
x
p
(
⋅
)
)
Softplus(·) = ln(1 + exp(·))
Softplus(⋅)=ln(1+exp(⋅))使得优化目标保持非负。然后在看其中的
H
(
⋅
)
H(·)
H(⋅)函数,这个函数就是一个经典的交叉熵损失
H
(
⋅
)
=
−
p
(
⋅
)
l
o
g
p
(
⋅
)
H(·) = −p(·)logp(·)
H(⋅)=−p(⋅)logp(⋅),但是这里并不是
y
y
y与
p
(
.
)
p(.)
p(.)进行交叉熵,而是
p
(
.
)
p(.)
p(.)和自身进行交叉熵。通过这个损失,本文仅仅是希望和
f
~
\tilde{f}
f~相比,
f
~
+
\tilde{f}^+
f~+可以有较小的熵,而
f
~
−
\tilde{f}^-
f~−拥有较大的熵,而不需要去关心具体分类结果是哪一类,从这个角度来说,本文不仅可以适用于分类这种可以有具体label的问题,更可以适用于语义分割等难以给出具体label的问题。
这里还要说一下,因为本文并不是提出的一个针对具体问题的具体模型,因此本文没有给出完整网络流程和整个优化目标,但是人感觉本文的模块最大的优势就是即插即用,或者是可以将其看成是一种增强版的IN。而对于这个模块,仅有其中分类器
ϕ
\phi
ϕ的参数是需要学习的,那么我们完全可将其进行预训练好后固定参数,然后再进行目标网络的优化。此外如果网络中所有Conv Bock所输出的高纬特征
F
F
F的通道数是相同的,那么穿插其中的SNR块中分类器的参数也完全可以是共享的,这就大大的简化了网络所需要训练的参数量。