ROC曲线的理解和python绘制ROC曲线
ROC曲线的理解
考虑一个二分问题,即将实例分成正类(positive)或负类(negative)。对一个二分问题来说,会出现四种情况。如果一个实例是正类并且也被 预测成正类,即为真正类(True positive),如果实例是负类被预测成正类,称之为假正类(False positive)。相应地,如果实例是负类被预测成负类,称之为真负类(True negative),正类被预测成负类则为假负类(false negative)。
1. 尽量把相关的识别出来,TPR越大越好
2. 把不相关的识别为相关,FPR越小越好
从列联表引入两个新名词。其一是真正类率(true positive rate ,TPR), 计算公式为TPR=TP/ (TP+ FN),刻画的是分类器所识别出的 正实例占所有正实例的比例。另外一个是假正类率(false positive rate, FPR),计算公式为FPR= FP / (FP + TN),计算的是分类器错认为正类的负实例占所有负实例的比例。还有一个真负类率(True Negative Rate,TNR),也称为specificity,计算公式为TNR=TN/ (FP+ TN) = 1 - FPR。
FPR = FP/(FP + TN) 负样本中的错判率(假警报率)
TPR = TP/(TP + TN) 判对样本中的正样本率(命中率)
ACC = (TP + TN) / P+N 判对准确率
如何绘制ROC曲线
- 对于一个特定的分类器和测试数据集,显然只能得到一个分类结果,即一组FPR和TPR结果。但是分类器的一个重要功能“概率输出”,即表示分类器认为某个样本具有多大的概率属于正样本(或负样本)。如果不是在(0,1)之间,可以通过sigmod函数映射到(0,1)之间。我们可以通过改变阈值(discrimination threashold)来判断测试集的样本是正例,还是负例。
- 例如:我们通过一个分类器得出了所有样本的属于正例的输出概率,并根据测试集中的概率值从大到小的排列。图中共有20个测试样本,“Class”一栏表示每个测试样本真正的标签(p表示正样本,n表示负样本),“Score”表示每个测试样本属于正样本的概率。
- 对于图中的第4个样本,其“discrimination threashold”值为0.6,那么样本1,2,3,4都被认为是正样本,因为它们的“Score”值都大于等于0.6,而其他样本则都认为是负样本。每次选取一个不同的threshold,我们就可以得到一组FPR和TPR,即ROC曲线上的一点。当选取不同的discrimination threashold就可以得到不同的 FPR和TPR值。
AUC值
AUC(Area Under Curve)被定义为ROC曲线下的面积,显然这个面积的数值不会大于1。又由于ROC曲线一般都处于y=x这条直线的上方,所以AUC的取值范围在0.5和1之间。使用AUC值作为评价标准是因为很多时候ROC曲线并不能清晰的说明哪个分类器的效果更好,而作为一个数值,对应AUC更大的分类器效果更好。
- 如上,是三条ROC曲线,在0.23处取一条直线。那么,在同样的低FPR=0.23的情况下,红色分类器得到更高的PTR。也就表明,ROC越往上,分类器效果越好。我们用一个标量值AUC来量化他。
AUC的物理意义
- 假设分类器的输出是样本属于正类的socre(置信度),则AUC的物理意义为,任取一对(正、负)样本,正样本的score大于负样本的score的概率。
ROC曲线的优势
ROC曲线有个很好的特性:当测试集中的正负样本的分布变化的时候,ROC曲线能够保持不变。在实际的数据集中经常会出现类不平衡(class imbalance)现象,即负样本比正样本多很多(或者相反),而且测试数据中的正负样本的分布也可能随着时间变化。
- AUC = 1,是完美分类器,采用这个预测模型时,不管设定什么阈值都能得出完美预测。绝大多数预测的场合,不存在完美分类器。
- 0.5 < AUC < 1,优于随机猜测。这个分类器(模型)妥善设定阈值的话,能有预测价值。
- AUC = 0.5,跟随机猜测一样(例:丢铜板),模型没有预测价值。
- AUC < 0.5,比随机猜测还差;但只要总是反预测而行,就优于随机猜测。
- 在上图中,(a)和(c)为ROC曲线,(b)和(d)为Precision-Recall曲线。(a)和(b)展示的是分类其在原始测试集(正负样本分布平衡)的结果,(c)和(d)是将测试集中负样本的数量增加到原来的10倍后,分类器的结果。可以明显的看出,ROC曲线基本保持原貌,而Precision-Recall曲线则变化较大。
计算AUC
- 第一种方法:AUC为ROC曲线下的面积,那我们直接计算面积可得。面积为一个个小的梯形面积之和。计算的精度与阈值的精度有关。
- 第二种方法:根据AUC的物理意义,我们计算正样本score大于负样本的score的概率。取N*M(N为正样本数,M为负样本数)个二元组,比较score,最后得到AUC。时间复杂度为O(N*M)。
- 第三种方法:与第二种方法相似,直接计算正样本score大于负样本的概率。我们首先把所有样本按照score排序,依次用rank表示他们,如最大score的样本,rank=n(n=N+M),其次为n-1。那么对于正样本中rank最大的样本,rank_max,有M-1个其他正样本比他score小,那么就有(rank_max-1)-(M-1)个负样本比他score小。其次为(rank_second-1)-(M-2)。最后我们得到正样本大于负样本的概率为:
时间复杂度为O(N+M)。
python 绘制ROC曲线代码
- 分类器的分类的结果:predStrengths:
- [[-0.646419 0.53886223 0.91726555 0.21712009 -0.69768794 1.22181293
1.22748297 0.58145314 -0.40165729 0.03508613 0.27123572 0.59407783
1.53203035 0.64819347 1.04739323 -0.40165729 -1.02662219 0.5606821
0.34364609 -0.40784481 -0.02469954 1.53203035 1.1676973 1.0995114
0.73717581 0.23749438 0.52166747 0.85052522 0.5606821 1.69342306
-1.02662219 -0.03331166 0.95841088 0.6538635 -0.40165729 -1.58160132
-0.32315478 -0.69921001 1.22748297 0.23749438 -1.58160132 0.23749438
-0.03331166 -0.49069632 -0.81111385 -1.58160132 -1.27705394 1.22748297
-0.10329743 -1.33116957 0.91726555 0.59407783 0.91726555 0.5606821
0.20691998 0.27123572 0.5606821 -0.02469954 -0.70343447 -1.58160132
1.1676973 0.95841088 -0.37284836 -0.19233647 0.21712009 0.68306018
0.18462128 1.53203035 0.20077111 1.22748297 0.18537621 1.53203035
-0.01293737 -0.32647673 0.5606821 -0.03331166 1.31646531 1.69342306
-0.07846957 0.02322857 -0.70620467 1.2613064 0.21712009 -1.40405878
0.59407783 1.47224468 -0.40165729 0.5606821 0.89862521 1.22748297
0.40165729 0.20691998 1.1550726 -0.07279954 0.79073955 1.41702908
0.62790127 -0.40784481 1.1676973 -0.40784481 1.31646531 -0.95663642
0.4544483 0.90429525 0.29848817 0.89862521 0.50089643 1.1676973
-0.86522948 0.21712009 -0.35372918 0.85052522 1.22748297 -0.01293737
0.38747052 1.56689706 1.99797044 0.54030781 0.44691059 0.80958618
1.1676973 -1.22493577 1.22748297 -0.27513129 0.84485519 1.40550435
-0.38301695 0.48217959 0.07292214 0.86522948 0.5606821 0.5273375
1.22748297 -1.27705394 -0.40165729 -0.07846957 1.53203035 0.80958618
0.90429525 0.85052522 -0.12406847 -1.58160132 1.53203035 0.36264266
0.98760756 1.38320565 0.91726555 -0.24749538 1.42269911 1.31646531
-1.00798185 0.11221091 1.1550726 1.93818477 0.84485519 1.22748297
1.11413356 1.10095697 1.22181293 -0.36264266 0.5606821 0.58145314
-0.57967867 0.64819347 0.16432908 1.22748297 1.53203035 0.27366031
-0.63171474 0.91726555 -0.65827656 0.05809528 -0.65827656 -0.77909388
0.5606821 1.36456531 -1.24647954 -0.24026458 -0.65208904 -0.40165729
0.5606821 -1.40405878 1.11413356 1.36456531 -0.40784481 0.90429525
0.83188488 0.6538635 0.58145314 0.64654161 -0.01902951 1.69342306
-1.33116957 0.34364609 1.31646531 0.02322857 -0.0976274 1.1198036
0.21712009 1.22181293 -0.27513129 0.81111385 1.53203035 -0.07846957
-0.65208904 0.44124055 0.5606821 0.36264266 1.1550726 0.07292214
-0.03331166 0.89862521 0.84485519 1.20884263 1.53203035 0.5606821
-0.63790227 1.1676973 0.84485519 -0.39598725 0.9932776 -0.35372918
0.64819347 1.22748297 1.53203035 -0.65827656 0.5606821 -1.02662219
0.93811868 -0.39598725 -1.58160132 0.5606821 0.52166747 0.79073955
-0.32647673 1.22748297 1.1676973 0.27123572 -0.07424511 -1.33116957
0.91726555 0.54891992 0.6538635 0.59407783 -1.58160132 -0.40165729
-0.65208904 -0.02469954 1.53203035 -0.40784481 0.54030781 0.21712009
-0.95663642 -0.49069632 -0.01902951 1.1198036 -0.65827656 1.53203035
0.20691998 0.49937436 1.13595343 -0.01293737 -0.29542349 0.84485519
0.83188488 0.34364609 0.54030781 0.95675902 0.90429525 0.83188488
-0.65208904 -0.40165729 -1.58160132 0.38313956 -0.32924692 0.69768794
1.47224468 1.22748297 0.83188488 1.42269911 -0.65827656 -0.35372918
-0.01293737 1.53203035 0.95841088 -1.04249592 0.23749438 0.5606821
1.20719077 0.91726555 -0.10329743 -0.57967867 0.27123572 1.69342306
0.05809528 -0.65208904 -1.02662219 0.27123572 0.5606821 ]] - 实际的标签:classLabels
- [-1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, 1.0, 1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, 1.0, -1.0, -1.0, -1.0, 1.0, -1.0]
def plotROC(predStrengths, classLabels):
import matplotlib.pyplot as plt
cur = (1.0,1.0) #cursor
ySum = 0.0 #variable to calculate AUC
numPosClas = sum(array(classLabels)==1.0)
yStep = 1/float(numPosClas); xStep = 1/float(len(classLabels)-numPosClas)
sortedIndicies = predStrengths.argsort()#get sorted index, it's reverse
fig = plt.figure()
fig.clf()
ax = plt.subplot(111)
#loop through all the values, drawing a line segment at each point
for index in sortedIndicies.tolist()[0]:
if classLabels[index] == 1.0:
delX = 0; delY = yStep;
else:
delX = xStep; delY = 0;
ySum += cur[1]
#draw line from cur to (cur[0]-delX,cur[1]-delY)
ax.plot([cur[0],cur[0]-delX],[cur[1],cur[1]-delY], c='b')
cur = (cur[0]-delX,cur[1]-delY)
ax.plot([0,1],[0,1],'b--')
plt.xlabel('False positive rate'); plt.ylabel('True positive rate')
plt.title('ROC curve for AdaBoost horse colic detection system')
ax.axis([0,1,0,1])
plt.show()
print "the Area Under the Curve is: ",ySum*xStep