前言
如果你对这篇文章感兴趣,可以点击「【访客必读 - 指引页】一文囊括主页内所有高质量博客」,查看完整博客分类与对应链接。
概述
首先从一个流感的例子讲起,医院在八月根据当月数据训练了模型 f f f,假设其特征 x \bm{x} x 为「有无咳嗽」,预测标签 y y y 为「有无得流感」。
后续几个月模型 f f f 运转良好,但到第二年二月时,医院发现 f f f 预测为「得流感」的人数大幅增加,此时我们知道这与「冬季是流感高发期」有关。但一个问题随即出现了,用八月数据训出的 f f f 是否在二月也能有效预测,其在八月数据上学得的先验是否会影响二月时的判断。
将问题形式化,我们可以发现八月和二月的 p ( y ∣ x ) = p ( p(y\mid \bm{x})=p( p(y∣x)=p(流感 ∣ \mid ∣ 咳嗽 ) ) ) 和 p ( y ) = p ( p(y)=p( p(y)=p(流感 ) ) ) 明显发生了变化,因此过往在「covariate shift」上的研究不再适用。
继续深入,我们可以发现
p
(
x
∣
y
)
=
p
(
p(\bm{x}\mid y)=p(
p(x∣y)=p(咳嗽
∣
\mid
∣ 流感
)
)
) 似乎并没有发生太大的变化,由此引入本篇文章所关注的「label shift
」问题,其代表下述这种情况:
- 标签边际分布 p ( y ) p(y) p(y) 发生变化,但条件分布 p ( x ∣ y ) p(\bm{x}\mid y) p(x∣y) 不变
随后文中提出「Black Box Shift Estimation (BBSE)」方法,利用「黑盒预测器」来估计变化的 p ( y ) p(y) p(y),且仅要求其对应「混淆矩阵 (confusion matrices)」是可逆的,即使预测器是 biased,inaccurate 或 uncalibrated。
问题设定
源域
:
X
×
Y
\mathcal{X}\times \mathcal{Y}
X×Y 上的分布
P
P
P,
D
=
{
(
x
i
,
y
i
)
}
i
=
1
n
D=\{(\bm{x}_i, y_i)\}_{i=1}^n
D={(xi,yi)}i=1n,基于
D
D
D 训练得到的黑盒模型
f
:
X
→
Y
f:\mathcal{X}\rightarrow \mathcal{Y}
f:X→Y
目标域
:
X
×
Y
\mathcal{X}\times \mathcal{Y}
X×Y 上的分布
Q
Q
Q,
X
′
=
[
x
1
′
;
.
.
.
;
x
m
′
]
X'=[\bm{x}_1';...;\bm{x}_m']
X′=[x1′;...;xm′]
目标
:检测
P
→
Q
P\rightarrow Q
P→Q 是否发生了「label shift」,若发生了则重新训练模型,使其适应分布
Q
Q
Q
三大假设
:
- 「label shift / target shift」假设:
p ( x ∣ y ) = q ( x ∣ y ) ∀ x ∈ X , y ∈ Y p(\boldsymbol{x} \mid y)=q(\boldsymbol{x} \mid y) \quad \forall x \in \mathcal{X}, y \in \mathcal{Y} p(x∣y)=q(x∣y)∀x∈X,y∈Y - ∀ y ∈ Y \forall y\in \mathcal{Y} ∀y∈Y,若 q ( y ) > 0 q(y)>0 q(y)>0 则 p ( y ) > 0 p(y)>0 p(y)>0
-
f
f
f 对应的混淆矩阵 (confusion matrix)
C
p
(
f
)
\mathrm{C}_p(f)
Cp(f) 可逆,矩阵定义如下:
C P ( f ) : = p ( f ( x ) , y ) ∈ R ∣ Y ∣ × ∣ Y ∣ \mathbf{C}_P(f):=p(f(x), y) \in \mathbb{R}^{|\mathcal{Y}| \times|\mathcal{Y}|} CP(f):=p(f(x),y)∈R∣Y∣×∣Y∣
BBSE
「Black Box Shift Estimation (BBSE)」方法主要用于估计
w
(
y
)
=
q
(
y
)
/
p
(
y
)
w(y)=q(y)/p(y)
w(y)=q(y)/p(y),其核心思路如下:
q
(
y
^
)
=
∑
y
∈
Y
q
(
y
^
∣
y
)
q
(
y
)
=
∑
y
∈
Y
p
(
y
^
∣
y
)
q
(
y
)
=
∑
y
∈
Y
p
(
y
^
,
y
)
q
(
y
)
p
(
y
)
\begin{aligned} q(\hat{y}) &=\sum_{y \in \mathcal{Y}} q(\hat{y} \mid y) q(y) \\ &=\sum_{y \in \mathcal{Y}} p(\hat{y} \mid y) q(y)=\sum_{y \in \mathcal{Y}} p(\hat{y}, y) \frac{q(y)}{p(y)} \end{aligned}
q(y^)=y∈Y∑q(y^∣y)q(y)=y∈Y∑p(y^∣y)q(y)=y∈Y∑p(y^,y)p(y)q(y)
其中
y
^
\hat{y}
y^ 即
f
f
f 给出的伪标记,而
q
(
y
^
∣
y
)
=
p
(
y
^
∣
y
)
q(\hat{y}\mid y)=p(\hat{y}\mid y)
q(y^∣y)=p(y^∣y) 则来自于下述推导:
q
(
y
^
∣
y
)
=
∑
x
∈
X
q
(
y
^
∣
x
,
y
)
q
(
x
∣
y
)
=
∑
x
∈
X
q
(
y
^
∣
x
,
y
)
p
(
x
∣
y
)
=
∑
x
∈
X
p
f
(
y
^
∣
x
)
p
(
x
∣
y
)
=
∑
x
∈
X
p
(
y
^
∣
x
,
y
)
p
(
x
∣
y
)
=
p
(
y
^
∣
y
)
\begin{aligned} &q(\hat{y} \mid y)=\sum_{\boldsymbol{x} \in \mathcal{X}} q(\hat{y} \mid \boldsymbol{x}, y) q(\boldsymbol{x} \mid y)=\sum_{\boldsymbol{x} \in \mathcal{X}} q(\hat{y} \mid \boldsymbol{x}, y) p(\boldsymbol{x} \mid y) \\ &=\sum_{\boldsymbol{x} \in \mathcal{X}} p_f(\hat{y} \mid \boldsymbol{x}) p(\boldsymbol{x} \mid y)=\sum_{\boldsymbol{x} \in \mathcal{X}} p(\hat{y} \mid \boldsymbol{x}, y) p(\boldsymbol{x} \mid y)=p(\hat{y} \mid y) \end{aligned}
q(y^∣y)=x∈X∑q(y^∣x,y)q(x∣y)=x∈X∑q(y^∣x,y)p(x∣y)=x∈X∑pf(y^∣x)p(x∣y)=x∈X∑p(y^∣x,y)p(x∣y)=p(y^∣y)
其关键部分在于
q
(
x
∣
y
)
=
p
(
x
∣
y
)
q(\bm{x}\mid y)=p(\bm{x}\mid y)
q(x∣y)=p(x∣y) 的假设以及
y
^
⊥
⊥
y
∣
x
\hat{y} \perp \!\!\! \perp y \mid \boldsymbol{x}
y^⊥⊥y∣x 的条件独立性。随后便可以得到:
μ
y
^
=
C
y
^
∣
y
μ
y
=
C
y
^
,
y
w
w
^
=
C
^
y
^
,
y
−
1
μ
^
y
^
μ
^
y
=
diag
(
ν
^
y
)
w
^
\begin{gathered} \mu_{\hat{y}}=\mathrm{C}_{\hat{y} \mid y} \mu_y=\mathrm{C}_{\hat{y}, y} w \\ \hat{\boldsymbol{w}}=\hat{\mathbf{C}}_{\hat{y}, y}^{-1} \hat{\boldsymbol{\mu}}_{\hat{y}} \\ \hat{\boldsymbol{\mu}}_y=\operatorname{diag}\left(\hat{\boldsymbol{\nu}}_y\right) \hat{\boldsymbol{w}} \end{gathered}
μy^=Cy^∣yμy=Cy^,yww^=C^y^,y−1μ^y^μ^y=diag(ν^y)w^
其中各符号定义如下,其核心思想就是本节最开头的公式,只不过为了严谨而引入了大量符号,但实质不变。
理论保障
首先是「Consistency」的保证:
其次是「Error bounds」方面的保证:
根据上述「Error bounds」的结果,可以发现在选择黑盒模型时,「
C
y
^
,
y
\mathrm{C}_{\hat{y}, y}
Cy^,y 最小奇异值」越大的模型越合适。
Label-Shift 检测
在最开头的三大假设下, q ( y ) = p ( y ) ⇔ p ( y ^ ) = q ( y ^ ) q(y)=p(y)\Leftrightarrow p(\hat{y})=q(\hat{y}) q(y)=p(y)⇔p(y^)=q(y^),因此使用「two-sample tests」对 p ( y ^ ) = q ( y ^ ) p(\hat{y})=q(\hat{y}) p(y^)=q(y^) 进行检测即可。
让模型适应新分布
计算出
w
^
\hat{\bm{w}}
w^ 后,采用「importance weighted ERM」在源域数据集
D
\mathcal{D}
D 上重新训练模型即可,具体训练目标如下:
L
=
∑
i
=
1
n
w
^
i
⋅
ℓ
(
y
i
,
x
i
)
\mathcal{L}=\sum_{i=1}^n \hat{w}_i\cdot \ell\left(y_i, \bm{x}_i\right)
L=i=1∑nw^i⋅ℓ(yi,xi)
整体算法如下:
检测 Label-Shift 假设成立
采用「kernel two-sample tests」检测下述式子是否成立:
E
p
[
w
(
y
)
k
(
ϕ
(
x
)
,
⋅
)
]
=
E
q
[
k
(
ϕ
(
x
)
,
⋅
)
]
\mathbb{E}_p[\boldsymbol{w}(y) k(\phi(\boldsymbol{x}), \cdot)]=\mathbb{E}_q[k(\phi(\boldsymbol{x}), \cdot)]
Ep[w(y)k(ϕ(x),⋅)]=Eq[k(ϕ(x),⋅)]
即转化为下述 MMD 距离的计算:
∥
1
n
∑
i
=
1
n
[
w
^
(
y
i
)
k
(
ϕ
(
x
i
)
,
⋅
)
]
−
1
m
∑
j
=
1
m
k
(
ϕ
(
x
j
′
)
,
⋅
)
∥
H
2
\left\|\frac{1}{n} \sum_{i=1}^n\left[\hat{\boldsymbol{w}}\left(y_i\right) k\left(\phi\left(\boldsymbol{x}_i\right), \cdot\right)\right]-\frac{1}{m} \sum_{j=1}^m k\left(\phi\left(\boldsymbol{x}_j^{\prime}\right), \cdot\right)\right\|_{\mathcal{H}}^2
n1i=1∑n[w^(yi)k(ϕ(xi),⋅)]−m1j=1∑mk(ϕ(xj′),⋅)
H2