java计算混淆矩阵(分类指标:查准率P,查全率R,P和R的调和均值F1,正确率A)

【0】README

本文使用 java 计算混淆矩阵,并利用 混淆矩阵值计算 分类指标;通用分类指标有: 查准率,查全率,查准率和查全率的调和均值F1值,正确率, AOC, AUC等;本文计算前4个指标;(附源代码和结果截图)


【1】什么是混淆矩阵(借用自己PPT截图)


【2】查准率和查全率的区别

查准率:查准率表示选出的样本中有多少比例样本是正例(期望样本);

查全率:查全率表示有多少比例的正样本(期望样本)被选出来了;


【3】如何计算多分类混淆矩阵的评价指标(摘自周志华老师的机器学习,极力推荐大家买一本



【4】源码如下

// 计算混淆矩阵,并根据混淆矩阵计算 10次交叉验证下的 评估指标均值(精确度, 召回率, F值, 准确率 这4个指标) 
public class SingleConfusionMatrix {
//	C:\Users\pacoson\Desktop\confusion_matrix
	private static String dir = "C:" + File.separator + "Users" + File.separator + "pacoson" + File.separator 
			+ "Desktop" + File.separator + "confusion_matrix";
	
	private static double[][][] averages = new double[10][3][4]; // 10次交叉验证, 3个unknown(12, 24, 48),4个度量指标(查全率,查准率,F1,准确率)
	private static int fold = 0;
	private static int counter = 6;
	
	public static void main(String[] args) {
		File file = new File(dir);
		
		showFiles();
	}
	
	// show files. 
	public static void showFiles() {
		File[] files = new File(dir).listFiles();  
		
		for(File file: files) {  // 遍历 dir 目录下的所有文件
			String filename = file.getName(); 
			String prefix = filename.split("_")[0]; 
			if(prefix.length() > 1) continue;
			
			fold = Integer.valueOf(prefix); // 10次交叉验证的编号
			System.out.println("\n====== fold=" + fold + "======"); 
			
			double[][] array = new double[100][6]; // item 数组
			DataRead reader = new DataRead(file.getAbsolutePath());	
			reader.readDataToArray(1, 1, array); // 数据读取完毕 
			
			computeConfusion(0, array); // 预测长度12
//			computeConfusion(1, array); // 预测长度24
//			computeConfusion(2, array); // 预测长度48
			break;
		}
//		computeAverage();
	}
	
	// column = 1(12), 2(24), 3(48)
	static void computeConfusion(int column, double[][] array) { 
		// 1.计算confusion12/24/48: TP FN FP TN
		int[][] confusions = new int[6][4];
		counter = 6;
		
		for (int id = 0; id < 6; id++) {
			for (int i = 0; i < array.length; i++) {
				if(array[i][column] == id) {
					if(array[i][column+3] == id) // column = 1(12), 2(24), 3(48)
						confusions[id][0]++; // TP
					else
						confusions[id][1]++; // FN
				} else if(array[i][column] != id) {
					if(array[i][column+3] == id)
						confusions[id][2]++; // FP
					else
						confusions[id][3]++; // TN
				}
			}
		}
		
		// 2.计算 统计指标: 
		// 精确度P=TP/(TP+FP), 查准率
		// 召回率R= TP/(TP+FN), 查全率
		// f1值=2*P*R/(P+R)
		// 准确率=(TP+TN)/(TP+FN+FP+TN)
		double[][] metrices = new double[6][4]; // 精确度, 召回率, F1值, 分类准确率
		
		for (int i = 0; i < confusions.length; i++) { 
			double[] metric = metrices[i];
			int[] confusion = confusions[i];
			
			System.out.print("confusion matrix: TP, FN, FP, TN: ");
			for (int j = 0; j < confusion.length; j++) { // 打印每个混淆矩阵
				System.out.print(confusion[j] + ", ");
			}
			System.out.println();
			
			if(confusion[0] + confusion[2] != 0) // 分母不能为零.
				metric[0] =  (double)confusion[0] / (confusion[0] + confusion[2]); // 精确度
			
			if(confusion[0] + confusion[1] != 0) // 分母不能为零.
				metric[1] =  (double)confusion[0] / (confusion[0] + confusion[1]); // 召回率
			
			if(metric[0] + metric[1] != 0) // 分母不能为零.
				metric[2] =  (double)2*metric[0]*metric[1]/(metric[0] + metric[1]); // f值
			
			metric[3] =  (double)(confusion[0]+confusion[3]) / (confusion[0] + confusion[1]+ confusion[2]+ confusion[3]); // 准确率
			
			if(confusion[3] == 100) { // 如果 TN == 100, 表明没有这个类成员(TN表示真实类别不是该类别且预测类别也不是该类别,那如果总数为100,则没有模型没有选出该类别)。
				metric[3] = 0;
				counter--;
			}
		}
		
		// 3.求均值(宏精确度, 宏召回率, 宏F1值, 宏准确率)
//		double[] average = new double[4];
//		private static double[][][] averages = new double[10][3][4]; // 10次交叉验证, 3个unknown(12, 24, 48),4个度量指标(查全率,查准率,F1,准确率)
//		double[][] metrices = new double[6][4]; // 精确度, 召回率, F1值, 分类准确率
		
		double[] average = averages[fold][column];
		System.out.println("counter = " + counter);
		
		for (int j = 0; j < metrices[0].length; j++) {
			double sum = 0;
			for (int i = 0; i < metrices.length; i++) {
				 sum += metrices[i][j];
			}
			average[j] = sum/counter;
			System.out.print(average[j] + " ");
		}
		System.out.println();
	}
	
	public static void computeAverage() {
		System.out.println("\n ===  计算10折交叉验证的统计指标均值  === \n");
		
		// 预测长度为 12(j==0), 24(j==1), 48(j==2)
		for (int j = 0; j < 3; j++) { 			
			for (int k = 0; k < averages[0][0].length; k++) { // 4 个 items
				double sum = 0;
				for (int i = 0; i < averages.length; i++) { // 行(10折) 
					sum += averages[i][j][k];
				}
				System.out.print(sum/10 + " ");
			}
			System.out.println();
		}
	}
} 




Tips: 10次交叉验证实验,只需要调用其中的 computeAverage() 方法 就可以计算 其10次的均值了。(这里只求出了某次交叉实验的均值)


评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值