Knn算法讲解

本文深入介绍了K-近邻(K-Nearest Neighbor, KNN)算法,包括其基本原理、距离计算方法(如欧式距离和曼哈顿距离)以及懒惰学习的概念。通过实例展示了如何使用KNN算法对鸢尾花数据集进行分类,并提供了Java代码实现,包括数据预处理、训练集与测试集划分、预测及精度评估。此外,还讨论了K值的选择对预测准确性的影响。
摘要由CSDN通过智能技术生成

简介:

  • 全名K-nearest neighber 算法,它是一个根据训练集并且非规则的分类算法。
  • Cover 和 Hart 在 1968年提出的最初的邻近算法
  • 通过测量不同特征值之间的距离方法进行分类
  • 属于懒惰学习方法Lazy learner
  • Eager Learners:接收测试数据之前,对已有的数据构建一个分类模型。
  • Lazy Learners (instance-based learners) : 在训练阶段仅仅把样本保存下来,训练时间开销为零,待接收到测试样本后在进行处理。

算法原理:

存在一个样本集合(训练样本集),并且样本中的每个数据都有标签(也是类别)。

输入没有类别标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后提取样本集最相似的K个分类标签。

选择这个K个标签中次数最多的分类为新数据的分类。

举例:

五列数据分别对应花萼长度、花萼宽度、花瓣长度、花瓣宽度和种类,其中种类分别为山鸢尾、变色鸢尾和维吉尼亚鸢尾三个类别。

5.1,3.5,1.4,0.2,Iris-setosa

4.9,3.0,1.4,0.2,Iris-setosa

4.7,3.2,1.3,0.2,Iris-setosa

4.6,3.1,1.5,0.2,Iris-setosa

7.0,3.2,4.7,1.4,Iris-versicolor

6.4,3.2,4.5,1.5,Iris-versicolor

6.9,3.1,4.9,1.5,Iris-versicolor

5.5,2.3,4.0,1.3,Iris-versicolor

6.3,2.5,5.0,1.9,Iris-virginica

6.5,3.0,5.2,2.0,Iris-virginica

6.2,3.4,5.4,2.3,Iris-virginica

5.9,3.0,5.1,1.8,Iris-virginica

我在这里只选用了每种花的四个数据样本,实际上数据样本可以非常的大。为了方便起见我把它们画在一个平面上(距离可能不准,请理解思路)。

 然后有一个测试数据进来:5.0,3.0,1.6,0.2 不知道它是那种花,因此我们给定一个k值等于三,找出他周围最相邻的三个数据。图中黑球像这样:

 我们直观的看出当 K = 3 时,其中有两个红球,一个蓝球,也就是说我们的待测数据跟红球样本相似度更大,所以判定此花是Iris-setosa。

问题来了:K值是我们给定的,哪我们怎样去判断距离呢?

第一步:明确距离坐标的含义,就是事物的特征值。比如此处的花萼长度、花萼宽度、花瓣长度、花瓣宽度就是坐标中的 x,y,z. 多了画不出来没关系,下面我们将介绍它的距离计算。

第二步:距离计算,我介绍常用的两种。

欧式距离:大家再熟悉熟悉不过了,例如:x=(x1,x2,...,xn) , y = (y1,y2,...,yn) d(x,y) := \sqrt{(x_{1}-y_{1})^{2}+(x_{2}-y_{2})^{2}+\cdot \cdot \cdot +(x_{n}-y_{n})^{2}}

曼哈顿距离:x = (x1,x2) , y = (y1,y2) d = | x1 - y1 | + | x2 - y2 |

 图中橘色的线段便是欧氏距离,而紫色的便是曼哈顿距离。两种距离各有千秋,像我们这道题:给定花的特征判断花的种类,欧式距离比较好,如果是城乡规划路径上的问题曼哈顿距离优势更准确,因为你不能直接越过楼房算位移,我们需要路程。

投票排序:

我们找出了 3 个距待测样本最近的数据后就要进行投票了。比如 第一个球是红色,红球+1,第二球是蓝色,蓝球+1,第三个球是红色红球+1。最终的结果是 :红(2票),蓝(1票)。代码中体现的是花的种类对应的数组下标。

根据票数我们就知道测试的结果是红球。可这样还不行,这只是预测,我们还需要来估算预测的准确率是多少,它与我们的K选取有关。比如:100个未知数据,我们预测的结果和真实的结果相比较,发现对了99个,那么预测的准确率便是 0.99.

代码:

package machinelearning.knn;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;

import weka.core.*;

public class KnnClassification {

	/**
	 * 曼哈顿距离
	 */
	public static final int MANHATTAN = 0;

	/**
	 * 欧式距离
	 */
	public static final int EUCLIDEAN = 1;

	/**
	 * 距离测量
	 */
	public int distanceMeasure = EUCLIDEAN; //现将距离测量置为1

	/**
	 * 生成随机距离
	 */
	public static final Random random = new Random();

	/**
	 * 邻居数目
	 */
	int numNeighbors = 7;

	/**
	 * The whole dataset.
	 */
	Instances dataset;

	/**
	 * 训练场地,由数据索引表示
	 */
	int[] trainingSet;

	/**
	 * 测试集,由数据索引表示
	 */
	int[] testingSet;

	/**
	 * 预测
	 */
	int[] predictions;

    // 导入数据
	public KnnClassification(String paraFilename) {
		try {
			FileReader fileReader = new FileReader(paraFilename);
			dataset = new Instances(fileReader);
			// The last attribute is the decision class.
			dataset.setClassIndex(dataset.numAttributes() - 1);
			fileReader.close();
		} catch (Exception ee) {
			System.out.println("Error occurred while trying to read \'" + paraFilename
					+ "\' in KnnClassification constructor.\r\n" + ee);
			System.exit(0);
		} // Of try
	}// Of the first constructor

	// 获得随机索引,数据随机化
	public static int[] getRandomIndices(int paraLength) {
		int[] resultIndices = new int[paraLength];

		// Step 1. 初始化resultIndices数组
		for (int i = 0; i < paraLength; i++) {
			resultIndices[i] = i;
		} // Of for i

		// Step 2. 随机交换
		int tempFirst, tempSecond, tempValue;
		// for 循环的目的是打乱数组中数字的顺序
		for (int i = 0; i < paraLength; i++) {
			// 生成两个随机索引,tempFirst,tempSecond
			tempFirst = random.nextInt(paraLength);
			tempSecond = random.nextInt(paraLength);

			// Swap. 交换
			tempValue = resultIndices[tempFirst];
			resultIndices[tempFirst] = resultIndices[tempSecond];
			resultIndices[tempSecond] = tempValue;
		} // Of for i

		return resultIndices;
	}// Of getRandomIndices
	
	// 把导入的数据一分为二,训练场地和测试场地
	public void splitTrainingTesting(double paraTrainingFraction) {
		// 
		int tempSize = dataset.numInstances();
		//得到一个测试数组
		int[] tempIndices = getRandomIndices(tempSize);
		int tempTrainingSize = (int) (tempSize * paraTrainingFraction);

		trainingSet = new int[tempTrainingSize];
		testingSet = new int[tempSize - tempTrainingSize];

		for (int i = 0; i < tempTrainingSize; i++) {
			trainingSet[i] = tempIndices[i];
		} // Of for i

		for (int i = 0; i < tempSize - tempTrainingSize; i++) {
			testingSet[i] = tempIndices[tempTrainingSize + i];
		} // Of for i
	}// Of splitTrainingTesting

	// 把预测的结果放在predictions数组中集中起来,后面做准确率计算
	public void predict() {
		predictions = new int[testingSet.length];
		for (int i = 0; i < predictions.length; i++) {
			predictions[i] = predict(testingSet[i]);
		} // Of for i
	}// Of predict

	// 进行预测
	public int predict(int paraIndex) {
		// 1. 找出相邻的 K 个邻居
		int[] tempNeighbors = computeNearests(paraIndex);
		
		// 2. 进行投票找出种类标签(数组下标表示,所以返回的是一个int型)
		int resultPrediction = simpleVoting(tempNeighbors);

		return resultPrediction;
	}// Of predict

       // 距离测量,两种选择,曼哈顿,欧式距离
	public double distance(int paraI, int paraJ) { // 猜测I和J是两个 
		double resultDistance = 0;
		double tempDifference;
		switch (distanceMeasure) {
		case MANHATTAN:
			for (int i = 0; i < dataset.numAttributes() - 1; i++) {
				tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
				if (tempDifference < 0) {
					resultDistance -= tempDifference;
				} else {
					resultDistance += tempDifference;
				} // Of if
			} // Of for i
			break;

		case EUCLIDEAN:
			for (int i = 0; i < dataset.numAttributes() - 1; i++) {
				tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
				resultDistance += tempDifference * tempDifference;
			} // Of for i
			break;
		default:
			System.out.println("Unsupported distance measure: " + distanceMeasure);
		}// Of switch

		return resultDistance;
	}// Of distance

	// 精确度
	public double getAccuracy() {
		
		double tempCorrect = 0;
		
		// 把正确的预测统计出来
		for (int i = 0; i < predictions.length; i++) {
			if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
				tempCorrect++;
			} // Of if
		} // Of for i

		return tempCorrect / testingSet.length;
	}// Of getAccuracy

	// 获得最近的 k 个邻居放入数组中
	public int[] computeNearests(int paraCurrent) {
		int[] resultNearests = new int[numNeighbors];
		boolean[] tempSelected = new boolean[trainingSet.length];
		double tempDistance;
		double tempMinimalDistance;
		int tempMinimalIndex = 0;

		// 借用循环找出距离最短的邻居下标
		for (int i = 0; i < numNeighbors; i++) {
			
			// 需要找的第 i 个邻居(共 K 个)
			tempMinimalDistance = Double.MAX_VALUE;
			
			// 从训练场地比较筛选
			for (int j = 0; j < trainingSet.length; j++) {
				
				// 如果这个结点被确定了最近结点中的一个,索引对应值true,跳过比较
				if (tempSelected[j]) {
					continue;
				} // Of if

				tempDistance = distance(paraCurrent, trainingSet[j]);
				if (tempDistance < tempMinimalDistance) {
					tempMinimalDistance = tempDistance;
					tempMinimalIndex = j;
				} // Of if
			} // Of for j

			resultNearests[i] = trainingSet[tempMinimalIndex];
			tempSelected[tempMinimalIndex] = true;
		} // Of for i

		System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
		return resultNearests;
	}// Of computeNearests

	// 进行投票比较
	public int simpleVoting(int[] paraNeighbors) {
		int[] tempVotes = new int[dataset.numClasses()];
		// 投票
		for (int i = 0; i < paraNeighbors.length; i++) {
			tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
		} // Of for i

		int tempMaximalVotingIndex = 0;
		int tempMaximalVoting = 0;
		// 比较找出票数最多的
		for (int i = 0; i < dataset.numClasses(); i++) {
			if (tempVotes[i] > tempMaximalVoting) {
				tempMaximalVoting = tempVotes[i];
				tempMaximalVotingIndex = i;
			} // Of if
		} // Of for i

		return tempMaximalVotingIndex;
	}// Of simpleVoting

	// 测试程序入口
	public static void main(String args[]) {
		KnnClassification tempClassifier = new KnnClassification("D:/data/iris.arff");
		tempClassifier.splitTrainingTesting(0.8);
		tempClassifier.predict();
		System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
	}// Of main

}// Of class KnnClassification

给出了详细的代码,其中的要点已经最大化的详细了。

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值