Weka中负责计算ROC曲线的是ThresholdCurve类的getCurve函数。
(1)getCurve函数
@param predictions 预测结果,一般是十折交叉验证得到的predictions
@param classIndex 作为正例的类别索引;对于多元分类问题,设定某个类别为正例,其余类别均视为负例
@return 由曲线各数据点组成的Instances(每个阈值对应一行)
// Excerpt from the loop over the predictions inside getCurve (the enclosing loop and the
// declarations of pred, totPos and totNeg are not shown): sums the instance weights of the
// actual positives (class == classIndex) and of all other classes (negatives).
if (pred.actual() == classIndex) {
totPos += pred.weight();// accumulate positive-class weight
} else {
totNeg += pred.weight();// accumulate negative-class weight
}
// Create the result dataset: header only (the curve attributes), no data rows yet.
Instances insts = makeHeader();
此时得到的insts(只有表头,尚无数据行)为:
@relation ThresholdCurve
@attribute 'True Positives' numeric
@attribute 'False Negatives' numeric
@attribute 'False Positives' numeric
@attribute 'True Negatives' numeric
@attribute 'False Positive Rate' numeric
@attribute 'True Positive Rate' numeric
@attribute Precision numeric
@attribute Recall numeric
@attribute Fallout numeric
@attribute FMeasure numeric
@attribute 'Sample Size' numeric
@attribute Lift numeric
@attribute Threshold numeric
@data
// Excerpt of getCurve: sweeps the predictions in ascending order of score and emits one curve
// point per distinct score value (the method header and the computation of probs, totPos and
// totNeg are not shown here).
Instances insts = makeHeader();
int[] sorted = Utils.sort(probs);// ascending sort; sorted holds indices into probs, not the values
// NOTE(review): probs presumably holds each prediction's predicted probability for classIndex
// (its computation is not shown). Sorting only orders by score: low scores come first, which for a
// good classifier are mostly negatives, with positives clustered at the end - not a strict split.
// Start with every instance predicted positive: TP = totPos, FP = totNeg, TN = FN = 0
// (assuming the TwoClassStats(tp, fp, tn, fn) argument order, consistent with the setters below).
TwoClassStats tc = new TwoClassStats(totPos, totNeg, 0, 0);
double threshold = 0;
// Weights of the predictions seen since the last emitted point (tied scores share one threshold).
double cumulativePos = 0;
double cumulativeNeg = 0;
for (int i = 0; i < sorted.length; i++) {
// Emit a point at the first prediction and whenever the score strictly increases.
if ((i == 0) || (probs[sorted[i]] > threshold)) {
// Predictions scored below the new threshold are now classified negative:
// move their weight from TP to FN and from FP to TN.
tc.setTruePositive(tc.getTruePositive() - cumulativePos);
tc.setFalseNegative(tc.getFalseNegative() + cumulativePos);
tc.setFalsePositive(tc.getFalsePositive() - cumulativeNeg);
tc.setTrueNegative(tc.getTrueNegative() + cumulativeNeg);
threshold = probs[sorted[i]];
// Point for the rule "predict positive iff score >= threshold".
insts.add(makeInstance(tc, threshold));
cumulativePos = 0;
cumulativeNeg = 0;
if (i == sorted.length - 1) {
break;
}
}
// Accumulate this prediction's weight; it leaves TP/FP when the next higher threshold is reached.
NominalPrediction pred = (NominalPrediction) predictions.get(sorted[i]);
if (pred.actual() == classIndex) {
cumulativePos += pred.weight();
} else {
cumulativeNeg += pred.weight();
}
}
// make sure a zero point gets into the curve: unless the last point already classifies
// everything as negative (FN == totPos and TN == totNeg), append the point with TP = FP = 0.
if (tc.getFalseNegative() != totPos || tc.getTrueNegative() != totNeg) {
tc = new TwoClassStats(0, 0, totNeg, totPos);
// Threshold just above the highest score. Note that 10e-6 equals 1e-5, not 1e-6.
threshold = probs[sorted[sorted.length - 1]] + 10e-6;
insts.add(makeInstance(tc, threshold));
}
return insts;
利用Weka绘制ROC曲线并计算AUC的方法:
来自《数据挖掘与机器学习:WEKA应用技术与实践(第二版)》
/**
 * ROC demo from the book: evaluates NaiveBayes on weather.nominal with 10-fold
 * cross-validation, builds the threshold curve treating class index 0 as positive, and shows
 * the curve in a Swing window whose title string includes the area under the curve.
 *
 * @throws Exception if loading the data, cross-validation, or building the plot fails
 */
public static void test1() throws Exception {
    // Load the dataset and use its last attribute as the class.
    ArffLoader arffLoader = new ArffLoader();
    arffLoader.setSource(new File("./data/weather.nominal.arff"));
    Instances dataset = arffLoader.getDataSet();
    int lastAttribute = dataset.numAttributes() - 1;
    dataset.setClassIndex(lastAttribute);

    // 10-fold cross-validation with a fixed seed so the folds are reproducible.
    Classifier naiveBayes = new NaiveBayes();
    Evaluation evaluation = new Evaluation(dataset);
    evaluation.crossValidateModel(naiveBayes, dataset, 10, new Random(1));

    // Threshold curve with class index 0 as the positive class.
    ThresholdCurve thresholdCurve = new ThresholdCurve();
    int positiveClass = 0;
    Instances rocCurve = thresholdCurve.getCurve(evaluation.predictions(), positiveClass);

    // Wrap the curve as plot data.
    PlotData2D plot = new PlotData2D(rocCurve);
    plot.setPlotName(rocCurve.relationName());
    plot.addInstanceNumberAttribute();

    // Panel showing the curve, with the AUC (4 decimals) as its ROC label.
    ThresholdVisualizePanel panel = new ThresholdVisualizePanel();
    String auc = Utils.doubleToString(ThresholdCurve.getROCArea(rocCurve), 4);
    panel.setROCString("(Area under ROC=" + auc + ")");
    panel.setName(rocCurve.relationName());

    // Flag every point as connected so the curve is drawn joined up.
    boolean[] connect = new boolean[rocCurve.numInstances()];
    java.util.Arrays.fill(connect, true);
    plot.setConnectPoints(connect);
    panel.addPlot(plot);

    // Host the panel in a window that is disposed when the user closes it.
    JFrame frame = new JFrame("WEKA ROC: " + panel.getName());
    frame.setSize(500, 400);
    java.awt.Container content = frame.getContentPane();
    content.setLayout(new BorderLayout());
    content.add(panel, BorderLayout.CENTER);
    frame.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
    frame.setVisible(true);
}
注意:如果把交叉验证的随机种子Random(1)改为Random(1234)(书中代码使用的种子),得到的ROC图会稍有不同,因为随机种子决定了十折交叉验证中数据的划分方式。
另一个类似的方法也可参考:
/**
 * Second ROC demo: 10-fold cross-validates NaiveBayes on weather.nominal, prints the AUC and
 * the TP-rate / FP-rate columns of the threshold curve, and shows the curve in a Swing window.
 *
 * <p>The printed TP/FP arrays can be pasted into an external tool (the original text mentions
 * SPSS) to draw the ROC curve as a line chart.
 *
 * @throws Exception if loading the ARFF file, cross-validation, or building the plot fails
 */
public static void test2() throws Exception {
    ArffLoader loader = new ArffLoader();
    loader.setSource(new File("./data/weather.nominal.arff"));
    Instances data = loader.getDataSet();
    // The last attribute is used as the class.
    data.setClassIndex(data.numAttributes() - 1);

    // Train and evaluate the classifier with 10-fold cross-validation to obtain the Evaluation
    // object; the fixed seed makes the fold assignment (and therefore the curve) reproducible.
    Classifier cl = new NaiveBayes();
    Evaluation eval = new Evaluation(data);
    eval.crossValidateModel(cl, data, 10, new Random(1));

    // Build the threshold-curve Instances used for the ROC curve and the AUC. Besides the TP and
    // FP rates it carries further per-threshold statistics (e.g. Precision, Recall, FMeasure),
    // which can be used to draw other kinds of curves.
    ThresholdCurve tc = new ThresholdCurve();
    // classIndex is the index of the class to consider as "positive"
    int classIndex = 0;
    Instances result = tc.getCurve(eval.predictions(), classIndex);
    System.out.println("The area under the ROC curve: " + eval.areaUnderROC(classIndex));

    // Extract the TP-rate and FP-rate columns, e.g. for plotting the ROC curve elsewhere.
    int tpIndex = result.attribute(ThresholdCurve.TP_RATE_NAME).index();
    int fpIndex = result.attribute(ThresholdCurve.FP_RATE_NAME).index();
    double[] tpRate = result.attributeToDoubleArray(tpIndex);
    double[] fpRate = result.attributeToDoubleArray(fpIndex);
    System.out.println("TPRate " + Arrays.toString(tpRate));
    System.out.println("FPRate " + Arrays.toString(fpRate));

    // Display the ROC curve. getROCArea is static, so it is called through the class (not the tc
    // instance). It computes the AUC from the curve itself, whereas the value printed above comes
    // from the Evaluation object; the two are expected to agree.
    ThresholdVisualizePanel vmc = new ThresholdVisualizePanel();
    vmc.setROCString("(Area under ROC = "
            + Utils.doubleToString(ThresholdCurve.getROCArea(result), 4) + ")");
    vmc.setName(result.relationName());
    PlotData2D tempd = new PlotData2D(result);
    tempd.setPlotName(result.relationName());
    tempd.addInstanceNumberAttribute();
    // Flag every point as connected so the curve is drawn joined up.
    boolean[] cp = new boolean[result.numInstances()];
    for (int i = 0; i < cp.length; i++) {
        cp[i] = true;
    }
    tempd.setConnectPoints(cp);
    vmc.addPlot(tempd);

    // Show the panel in a window.
    String plotName = vmc.getName();
    javax.swing.JFrame jf = new javax.swing.JFrame("Weka Classifier Visualize: " + plotName);
    jf.setSize(500, 400);
    jf.getContentPane().setLayout(new BorderLayout());
    jf.getContentPane().add(vmc, BorderLayout.CENTER);
    // Dispose the window when the user closes it. This has the same effect as the former
    // windowClosing listener that only called dispose().
    jf.setDefaultCloseOperation(javax.swing.JFrame.DISPOSE_ON_CLOSE);
    jf.setVisible(true);
}