auc计算逻辑
实现auc,python
# -*- coding: utf-8 -*-
# @Author:
# @Date: 2021/3/3 11:07 上午
from sklearn.metrics import roc_auc_score
import time
def auc(y_true, y_score, bins=200):
    """Approximate the ROC AUC of binary labels using equal-width score buckets.

    Each score is placed into one of ``bins`` equal-width buckets covering
    [0, 1]. Every (positive, negative) pair contributes 1 when the positive
    sample sits in a higher bucket than the negative one and 0.5 when both
    share a bucket; the total is divided by the number of such pairs. Scores
    that fall into the same bucket are treated as ties, so the result is an
    approximation whose resolution is one bucket width.

    Args:
        y_true: Sequence of labels. A label equal to 1 is positive; any other
            value is counted as negative.
        y_score: Sequence of predicted scores in [0, 1], aligned with
            ``y_true``.
        bins: Number of buckets (defaults to 200).

    Returns:
        The approximate AUC as a float.

    Raises:
        ValueError: If the inputs differ in length, ``bins`` is not positive,
            a score lies outside [0, 1], or only one class is present.
    """
    if len(y_true) != len(y_score):
        raise ValueError(
            f"y_true and y_score must have the same length, "
            f"got {len(y_true)} and {len(y_score)}"
        )
    if bins < 1:
        raise ValueError(f"bins must be a positive integer, got {bins}")

    bin_width = 1 / bins
    p_bins = [0] * bins
    n_bins = [0] * bins

    # Bucket every sample by score, keeping positives and negatives apart.
    for label, score in zip(y_true, y_score):
        # Written as a negated range test so that NaN is rejected as well.
        if not 0 <= score <= 1:
            raise ValueError(f"scores must lie in [0, 1], got {score}")
        # A score of exactly 1.0 would map to index == bins; fold it into the
        # last bucket instead of indexing past the end of the list.
        index = min(int(score / bin_width), bins - 1)
        if label == 1:
            p_bins[index] += 1
        else:
            n_bins[index] += 1

    # Derive the class totals from the buckets so they always agree with the
    # ``label == 1`` test used above, even when labels are not strictly 0/1.
    m = sum(p_bins)
    n = sum(n_bins)
    if m == 0 or n == 0:
        raise ValueError(
            "AUC is undefined when only one class is present in y_true"
        )

    accumulated_n = 0
    pair = 0
    for p_count, n_count in zip(p_bins, n_bins):
        # Positives here beat every negative in lower buckets and tie (0.5)
        # with the negatives in the same bucket.
        pair += p_count * accumulated_n + p_count * n_count * 0.5
        # Updated after use so the running total lags one bucket behind.
        accumulated_n += n_count
    return pair / (m * n)
# Sample data for a quick comparison of the two implementations.
# (A nine-element sample that used to be assigned first was overwritten
# before use; to benchmark on a large input, repeat both lists, e.g.
# ``label * 1000000`` and ``score * 1000000``.)
label = [0, 1, 0, 1, 1]
score = [0.5, 0.3, 0.2, 0.8, 0.7]

# Run the bucketed implementation, then report its result and elapsed time.
begin = time.time()
auc1 = auc(label, score)
end = time.time()
print(auc1)
print(end - begin)

# Run scikit-learn's roc_auc_score on the same data; as in the original
# ordering, printing the result is included in the timed section.
begin = time.time()
print(roc_auc_score(label, score))
end = time.time()
print(end - begin)
实现auc,java
package com.xueqiu.infra.xdc.hive.udf;
import org.apache.hadoop.hive.ql.exec.UDF;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import java.util.Arrays;
/**
*
* @author
*/
/**
 * Hive UDF that computes an approximate ROC AUC for a single feature.
 *
 * <p>Scores are grouped into {@link #BINS} equal-width buckets over [0, 1].
 * Each (positive, negative) pair counts 1 when the positive sample falls in a
 * higher bucket than the negative one and 0.5 when both share a bucket; the
 * total is divided by the number of such pairs.
 */
public class RocAucScore extends UDF {

    /** Number of equal-width score buckets covering [0, 1]. */
    private static final int BINS = 200;

    /**
     * Computes the bucketed AUC of the given labels and scores.
     *
     * @param yTrue  labels joined by '#'; a label of 1 is positive, any other
     *               integer is counted as negative
     * @param yScore predicted scores in [0, 1] joined by '#', aligned with
     *               {@code yTrue}
     * @return the AUC as a decimal string; "0" when either argument is null
     *         or blank; "NaN" when only one class is present (the pair count
     *         is zero, so the ratio is 0.0 / 0.0)
     * @throws HiveException if the two lists differ in length, a value cannot
     *                       be parsed, or a score lies outside [0, 1]
     */
    public String evaluate(String yTrue, String yScore) throws HiveException {
        // Hive passes SQL NULL as a Java null; treat it like blank input.
        if (yTrue == null || yScore == null
                || yTrue.trim().isEmpty() || yScore.trim().isEmpty()) {
            return "0";
        }

        String[] labels = yTrue.split("#");
        String[] scores = yScore.split("#");
        // Explicit check instead of an assert: assertions are disabled by
        // default at runtime, and the original condition was inverted.
        if (labels.length != scores.length) {
            throw new HiveException("yTrue and yScore must contain the same number of values, got "
                    + labels.length + " and " + scores.length);
        }

        double binWidth = 1.0 / BINS;
        long[] positiveBins = new long[BINS];
        long[] negativeBins = new long[BINS];
        long positiveNum = 0;

        for (int i = 0; i < labels.length; i++) {
            int label;
            double score;
            try {
                label = Integer.parseInt(labels[i].trim());
                // Parsed as double: float rounding can push a score that sits
                // on a bucket boundary into the neighbouring bucket.
                score = Double.parseDouble(scores[i].trim());
            } catch (NumberFormatException e) {
                throw new HiveException("Unparseable label or score at position " + i
                        + ": '" + labels[i] + "' / '" + scores[i] + "'", e);
            }
            // Written as a negated range test so that NaN is rejected too.
            if (!(score >= 0.0 && score <= 1.0)) {
                throw new HiveException("Scores must lie in [0, 1], got " + scores[i]
                        + " at position " + i);
            }
            // A score of exactly 1.0 maps to index == BINS; fold it into the
            // last bucket instead of indexing past the end of the array.
            int index = Math.min((int) (score / binWidth), BINS - 1);
            if (label == 1) {
                positiveBins[index]++;
                // Counted here so the total always agrees with the
                // bucketing rule, even if labels are not strictly 0/1.
                positiveNum++;
            } else {
                negativeBins[index]++;
            }
        }
        long negativeNum = labels.length - positiveNum;

        double accumulatedNegatives = 0.0;
        double pair = 0.0;
        for (int k = 0; k < BINS; k++) {
            // Positives in this bucket beat every negative in lower buckets
            // and tie (0.5) with the negatives in the same bucket.
            pair += positiveBins[k] * accumulatedNegatives
                    + 0.5 * positiveBins[k] * negativeBins[k];
            // Updated after use so the running total lags one bucket behind.
            accumulatedNegatives += negativeBins[k];
        }
        return String.valueOf(pair / ((double) positiveNum * negativeNum));
    }

    /**
     * Small benchmark: repeats a nine-element sample 1,000,000 times, runs
     * {@link #evaluate} on it, and prints the result followed by the elapsed
     * time in whole seconds.
     */
    public static void main(String[] args) throws Exception {
        String yTrue = "1#0#1#1#0#1#0#0#1";
        String yScore = "0.5#0.4#0.3#0.7#0.4#0.3#0.7#0.4#0.3";
        StringBuilder sb1 = new StringBuilder();
        StringBuilder sb2 = new StringBuilder();
        int num = 1000000;
        for (int i = 0; i < num; i++) {
            sb1.append(yTrue);
            sb2.append(yScore);
            // Separate consecutive copies, but leave no trailing '#'.
            if (i < num - 1) {
                sb1.append("#");
                sb2.append("#");
            }
        }
        RocAucScore ras = new RocAucScore();
        long startTime = System.currentTimeMillis();
        System.out.println(ras.evaluate(sb1.toString(), sb2.toString()));
        long endTime = System.currentTimeMillis();
        long totalTime = endTime - startTime;
        System.out.println(totalTime / 1000);
    }
}