covariate shift现象以及解决方法

最近在重读paper《Batch Normalization》的时候,发现它在文中反复提到了一个概念“covariate shift”,而batch-Normalization的提出就是为了解决神经网络中(尤其是比较深的网络中的covariate shift现象)。我对这个概念很感兴趣,就花费时间去查了一些,今天在这里总结一下学到的东西。

首先我要来解释一下什么叫做covariate shift现象,这个指的是训练集的数据分布和预测集的数据分布不一致,这样的情况下如果我们在训练集上训练出一个分类器,肯定在预测集上不会取得比较好的效果。这种训练集和预测集样本分布不一致的问题就叫做“covariate shift”现象。比方说,我想训练一个模型根据人的血液样本来判断其有没有得血液病,对于负样本肯定就是收集一些血液病人的血液,但是对于正样本来说的话,其采样一定要合理,所采样例一定要满足整个人群中的分布。如果只采特定领域人群(比方说学校的学生)的血液作为正样本,那么我最终训练得到的模型,很难在所有人群中取得不错的效果,因为真实的预测集中学生只是正常人群中很少的一部分。(这个现象在迁移学习中也很常见)

要解决“covariate shift”问题,其实就是重新给训练集中的数据赋予一个新的权重即Reweight操作,比方说对于样本 x i x_i xi,它在训练集中的分布是 q ( x i ) q(x_i) q(xi),在预测集中的真实分布是 p ( x i ) p(x_i) p(xi),那么它的新权重就是 p ( x i ) q ( x i ) \frac{p(x_i)}{q(x_i)} q(xi)p(xi)。那么现在的问题就变成了如何确定样本 x i x_i xi在训练集和预测集中的真实分布。其实用的方法特别的巧妙,同样用的是机器学习的方法:Logistic Rgression,就是随机的从训练集和测试集随机的抽取样本,根据他们的来源不同,把来自训练集的样本标注为1,把来自预测集的样本标注为-1。把这份数据分成新的训练集和测试集,在训练集上训练模型,然后看该训练好的模型在测试集上的表现,如果表现的好,说明它能够很好的区分来自之前训练集和测试集的数据,就说明这些数据的分布不一致,反之亦然。具体的计算公式如下:
p ( z = 1 ∣ x i ) = p ( x i ) p ( x i ) + q ( x i ) p(z=1|x_i)=\frac{p(x_i)}{p(x_i)+q(x_i)} p(z=1xi)=p(xi)+q(xi)p(xi) //z=1表示该样本来自于之前的预测集分布 p p p,z=-1表示该样本来自于之前的训练集分布 q q q。当训练好了Logistic Regression分类器之后, p ( z = 1 ∣ x i ) = 1 1 + e − f ( x i ) p(z=1|x_i)=\frac{1}{1+e^{-f(x_i)}} p(z=1xi)=1+ef(xi)1,然后就很容易推出对于样本 x i x_i xi来说,它reweight的权值是 p ( z = 1 ∣ x i ) p ( z = − 1 ∣ x i ) = e f ( x i ) \frac{p(z=1|x_i)}{p(z=-1|x_i)}=e^{f(x_i)} p(z=1xi)p(z=1xi)=ef(xi),其中的 f ( x i ) f(x_i) f(xi)就是我们训练出来的分类器。

貌似感觉已经把covariate shift问题的解决方案讲完了,其实还有一个大前提,就是该用什么样的指标来判断是否已经出现了covariate shift现象(只有判断出现了covariate shift现象之后,才需要reweight样本权重,否则就不用了)。这里使用的指标叫做MCC(Matthews correlation coefficient),这个指标本质上是用一个训练集数据和预测集数据之间的相关系数,取值在[-1,1]之间,如果是1就是强烈的正相关,0就是没有相关性,-1就是强烈的负相关。它的具体计算和confusion matrix概念相关,下面来列举几个和confusion matrix相关的概念:
TP(True Positive):真实为1,预测为1
FN(False Negative):真实为1,预测为0
FP(False Positive):真实为0,预测为1
TN(True Negative):真实为0,预测为0
M c c = T P ∗ T N − F P ∗ F N ( T P + F P ) ( T P + F N ) ( T N + F P ) ( T N + F N ) Mcc=\frac{TP*TN-FP*FN}{\sqrt{(TP+FP)(TP+FN)(TN+FP)(TN+FN)}} Mcc=(TP+FP)(TP+FN)(TN+FP)(TN+FN) TPTNFPFN
(PS:衡量二分类效果的几个指标,ACC(准确率),Rec(召回率),F值,AUC,MCC,它们各自对应了自己的应用场景)
对于Logistic Regression分类器,计算其Mcc值,一般认为如果该值大于0.2,说明预测集和测试集相关度高,也就是说明分类器容易把在训练集上学习到的经验应用在预测集上,也就是说明出现了covariate shift的现象;如果小于0.2,就没有出现covariate shift现象。

  • 14
    点赞
  • 50
    收藏
    觉得还不错? 一键收藏
  • 20
    评论
评论 20
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值