腐败矩阵
利用腐败矩阵可以提高模型对含噪标签的鲁棒性。
具体应用如下:
一个带标签数据集中,有一部分信任的数据集
D
D
D,有部分是不信任的数据集
D
~
\widetilde D
D
,我们的目的是如何用上面标签含噪的数据集训练一个对标签噪声具有鲁棒性的网络。
第一步:估计腐败矩阵
腐败矩阵:就是腐败概率的KxK矩阵,K是类别数。腐败概率就是原本标签是 i,却分类成 j 的概率
C
i
j
=
p
(
y
~
=
j
∣
y
=
i
)
C_{ij} = p(\widetilde{y} =j|y=i)
Cij=p(y
=j∣y=i)
估计准备:
(1):
p
(
y
~
∣
x
,
y
)
∗
p
(
x
∣
y
)
=
p
(
y
~
∣
y
)
∗
p
(
x
∣
y
~
,
y
)
p(\widetilde{y}|x,y)*p(x|y)=p(\widetilde{y}|y)*p(x|\widetilde{y},y)
p(y
∣x,y)∗p(x∣y)=p(y
∣y)∗p(x∣y
,y)
将(1)式两边同时对x进行积分:
(2):
∫
p
(
y
~
∣
y
,
x
)
p
(
x
∣
y
)
d
x
=
p
(
y
~
∣
y
)
∫
p
(
x
∣
y
~
,
y
)
d
x
=
p
(
y
~
∣
y
)
\int p(\widetilde{y} \mid y, x) p(x \mid y) \mathrm{d} x=p(\widetilde{y} \mid y) \int p(x \mid \widetilde{y}, y) \mathrm{d} x=p(\widetilde{y} \mid y)
∫p(y
∣y,x)p(x∣y)dx=p(y
∣y)∫p(x∣y
,y)dx=p(y
∣y)
假设y和
y
^
\hat{y}
y^在给定x的情况下条件独立,,则有:
p
(
y
~
∣
y
,
x
)
=
p
(
y
~
∣
x
)
p(\widetilde{y} \mid y, x)=p(\widetilde{y} \mid x)
p(y
∣y,x)=p(y
∣x)
所以(2)式左边就可以看成
p
(
y
~
∣
x
)
p(\widetilde{y} \mid x)
p(y
∣x)的均值。
那
p
(
y
~
∣
x
)
p(\widetilde{y} \mid x)
p(y
∣x)怎么得到呢?
我们可以通过在不信任数据集
D
~
\widetilde D
D
上训练分类器得到
p
(
y
~
∣
x
)
p(\widetilde{y} \mid x)
p(y
∣x)的近似估计:
(3):
p
^
(
y
~
∣
y
,
x
)
≈
p
(
y
~
∣
y
,
x
)
\hat{p}(\widetilde{y} \mid y, x) \approx p(\widetilde{y} \mid y, x)
p^(y
∣y,x)≈p(y
∣y,x)
然后我们将在不信任数据集
D
~
\widetilde D
D
上训练分类器得到的分类器
p
^
(
y
~
∣
x
)
\hat{p}(\widetilde{y} \mid x)
p^(y
∣x)作用在信任的数据集上面,就可以得到腐败矩阵的估计值了:
(4):
C
^
i
j
=
1
∣
A
i
∣
∑
x
∈
A
i
p
^
(
y
~
=
j
∣
x
)
=
1
∣
A
i
∣
∑
x
∈
A
i
p
^
(
y
~
=
j
∣
y
=
i
,
x
)
≈
p
(
y
~
=
j
∣
y
=
i
)
.
\widehat{C}_{i j}=\frac{1}{\left|A_{i}\right|} \sum_{x \in A_{i}} \widehat{p}(\widetilde{y}=j \mid x)=\frac{1}{\left|A_{i}\right|} \sum_{x \in A_{i}} \widehat{p}(\widetilde{y}=j \mid y=i, x) \approx p(\widetilde{y}=j \mid y=i) .
C
ij=∣Ai∣1x∈Ai∑p
(y
=j∣x)=∣Ai∣1x∈Ai∑p
(y
=j∣y=i,x)≈p(y
=j∣y=i).
(4)式中
A
i
A_i
Ai表示信任数据集中标签为 i 的子集。
第二步 利用腐败标签即腐败矩阵提高网络的鲁棒性
我们初始化模型:
g
(
x
)
=
p
^
(
y
∣
x
;
θ
)
g(x)=\widehat{p}(y \mid x ; \theta)
g(x)=p
(y∣x;θ)
然后用信任集的损失:
ℓ
(
g
(
x
)
,
y
)
o
n
D
\ell(g(x),y)\quad on\quad D
ℓ(g(x),y)onD
和不信任集的损失:
(5):
ℓ
(
C
^
⊤
g
(
x
)
,
y
~
)
o
n
D
~
\ell\left(\widehat{C}^{\top} g(x), \widetilde{y}\right) \quad on \quad \widetilde D
ℓ(C
⊤g(x),y
)onD
解释一下在不信任集上的损失(5):
(6):
p
(
y
~
,
y
)
p
(
y
∣
x
)
=
p
(
y
~
∣
x
)
p(\widetilde{y},y)p(y \mid x)=p(\widetilde{y}\mid x)
p(y
,y)p(y∣x)=p(y
∣x)
由(6)式我们可以看出,虽然不信任集的损失(5)是在不信任标签
y
~
\widetilde{y}
y
下训练的,但得到
g
(
x
)
g(x)
g(x)却可以反映真实的预测标签
p
(
y
∣
x
)
p(y \mid x)
p(y∣x)。
所以我们最终可以得到一个很好的分类模型
g
(
x
)
=
p
^
(
y
∣
x
;
θ
)
g(x)=\widehat{p}(y \mid x ; \theta)
g(x)=p
(y∣x;θ)
论文参考:Using Trusted Data to Train Deep Networks on Labels Corrupted by Severe Noise. 2018 NIPS.