需求描述
对于二元分类算法,在训练完模型后,用testData对模型进行评估,可以得到*(score,label)的数据。在存在shreshold的情况下,我们认为score>=shreshold的情况下,该数据为positive*,相反则为negative。进一步得到*(prediction,label)的数据(prediction就是预测的label值),比较prediction和label*是否一致,就可以知道模型的好坏如何,也就是这里提到的二元分类算法的评估方法。(wikipedia关于二元分类的评估方法)。评估结果的核心是得到每条数据应该属于TP/TN/FP/FN具体哪一类。
通常,我们需要通过调试shreshold的值(在0和1之间)去寻找一个比较合理的结果。因此在elemental中提供了从0.01到0.99变化的100个不同的值对应的100个评估结果。总之,我们需要计算100遍评估结果。
原解决方案
shreshold依次从0.01变化到0.99,计算每次shreshold值应该具有的评估结果,每次都需要重新计算每条数据。具体逻辑代码如下:
val scoreAndLabelDatas:RDD[(double,double)]
(shreshold<- 0.01 to 0.99){
evalutionMetric(scoreAndLabelDatas,shreshold)
}
def evluationMetric(scoreAndLabelDatas:RDD[(double,double)],threshold:Double=0.5) {
scoreAndLabelDatas.map {
case (score,label)=> {
(score,label) match {
case tp if(score>=threshold && label==1.0)
case fp if(score>=threshold && label==0.0)
case tn if(score<threshold && label==0.0)
case fn if(score<threshold && label==1.0)
}
}
}
...
...
}
整个结果的计算时间就是单次evluationMetric获得结果的时间的100倍,即使scoreAndLabelDatas已经缓存的情况下。
优化方案
我们有必要重复计算那么多次么?
先看一条数据*(score,label)*,无论_shreshold_怎么变化,_score_和_label_都不会发生任何变化。
那么shreshold对于这条数据属于哪一类的影响因子是什么呢?
假设该数据为(0.25,1.0),当shreshold小于等于0.25的时,该数据属于TP,大于0.25时,该数据属于FN,也就是说0.25只是这条数据的一个跳变点(从TP跳变为FN的点),记为跳变值为25的TP Changer;
同理,如果数据为(0.36,0.0),当shreshold小于等于0.36的时,该数据属于FP,大于0.36时,该数据属于TN,0.36就是这条数据的跳变点(从FP跳变为TN的点),记为跳变值为36的FP Changer。
每条数据都存在一个跳变点,不是TP的跳变点就是FP的跳变点
因此利用跳变点的特性,我们修改计算的过程如下:
- 1.计算每条数据的跳变点的值
val (changerValue,eitherTpOrFp)=(score,label) match {
case (score,1.0)=>((score*100).toInt,Left(tp))
case (score,0.0)=>((score*100).toInt,Right(fp))
}
- 2.根据其属于TP Changer还是FP Changer,上面的结果分为两大类。每类合计出其具体Map[ChangerValue->Count],即每个跳变值具有多少个
val changerValues:RDD[(Int,Int)]
val changerMap=changerValues.reduceByKey(_+_).toMap
- 3.这样得到了TpChangerMap和FpChangerMap,每个集合都具备这样的性质。假设label=1.0的数据有10条,TpChangerMap为(1->1,23->4,61->2,79->3)。那么从0开始,所有数据都应该是TP,此时(TP,FN)的值为(10,0),当跳变值变化到2的时候,所有跳变值为1的TP都转化为了FN,(TP,FN)的值为(9,1)。以此类推,得到下面的数据
0时,(TP,FN)为(10,0);
2时,(TP,FN)为(9,1);10减去 跳变值为1的count=9
24时,(TP,FN)为(5,5);9减去 跳变值为23的count=5
62时,(TP,FN)为(3,7);5减去 跳变值为61的count=3
80时,(TP,FN)为(0,10);3减去 跳变值为79的count=0
因此跳变值从0变化到100,可以得到101组(TP,FN)的值,且每次值都是前一次结果的值减去跳变值对应的数据。换句话说,我们可以一次性得到shreshold变化100次的所有(TP,FN)的值,同理另外所有的(FP,TN)的值也可以一次性得到。
val allValues = (0 to thresholdNum)
.scanLeft(new Array[(Int, Int, Int, Int)](101)) {
case (values, thresholdIdx) => {
values.zipWithIndex.map {
case (value, idx) => {
val currentVal = if (null == value) (0, 0, 0, 0) else value
changeValues(currentVal, TpChangerMap(thresholdIdx),FpChangerMap(thresholdIdx))
}
}
}
}
.tail
对比###
假设_testData_的规模是M条数据,分布在N个partition上,并且假定所有数据都可以完美平行处理,那么:
旧方案的时间开销为:
单次TP/FN/FP/TN所有值的时间开销为:M/N
100次时间开销为:M/N*100
修改方案的时间开销为:
所有跳变值的时间开销为:M/N
获取跳变值Map的开销为:O(M/N)
获取所有TP/FN/FP/TN的时间开销为:O(100)
合计:M/N+O(M/N)+O(100)
由于M/N是大头,效率将大为提升。考虑到还有其他计算的影响,实际提升基本上在__十倍左右__