KNN学习

这篇博客介绍了KNN(K最近邻)算法的基本概念、工作原理和关键步骤,包括距离计算(如欧式距离)以及K值的选择。KNN算法是一种简单而有效的分类方法,适用于分类和回归问题。博主提供了Java实现的KNN分类器代码,包括距离计算、K值选择和交叉验证的过程,并讨论了算法的优缺点。文章通过实例展示了K值变化对分类结果的影响,强调了K值选择的重要性。
摘要由CSDN通过智能技术生成

学习来源:日撸 Java 三百行(51-60天,KNN与NB))_闵帆的博客-CSDN博客
第 51 天: kNN 分类器

一.KNN算法概述
KNN可以说是最简单的分类算法之一,同时,它也是最常用的分类算法之一,注意KNN算法是有监督学习中的分类算法,它看起来和另一个机器学习算法Kmeans有点像(Kmeans是无监督学习算法),但却是有本质区别的。那么什么是KNN算法呢,接下来我们就来介绍介绍吧。

二.KNN算法介绍
KNN的全称是K Nearest Neighbors,意思是K个最近的邻居,从这个名字我们就能看出一些KNN算法的蛛丝马迹了。K个最近邻居,毫无疑问,K的取值肯定是至关重要的。那么最近的邻居又是怎么回事呢?其实啊,KNN的原理就是当预测一个新的值x的时候,根据它距离最近的K个点是什么类别来判断x属于哪个类别。听起来有点绕,还是看看图吧。

在这里插入图片描述

图中绿色的点就是我们要预测的那个点,假设K=3。那么KNN算法就会找到与它距离最近的三个点(这里用圆圈把它圈起来了),看看哪种类别多一些,比如这个例子中是蓝色三角形多一些,新来的绿色点就归类到蓝三角了。
在这里插入图片描述

但是,当K=5的时候,判定就变成不一样了。这次变成红圆多一些,所以新来的绿点被归类成红圆。从这个例子中,我们就能看得出K的取值是很重要的。

明白了大概原理后,我们就来说一说细节的东西吧,主要有两个,K值的选取和点距离的计算。

2.1距离计算
要度量空间中点距离的话,有好几种度量方式,比如常见的曼哈顿距离计算,欧式距离计算等等。不过通常KNN算法中使用的是欧式距离,这里只是简单说一下,拿二维平面为例,,二维空间两个点的欧式距离计算公式如下:
在这里插入图片描述

这个高中应该就有接触到的了,其实就是计算(x1,y1)和(x2,y2)的距离。拓展到多维空间,则公式变成这样:
在这里插入图片描述

这样我们就明白了如何计算距离,KNN算法最简单粗暴的就是将预测点与所有点距离进行计算,然后保存并排序,选出前面K个值看看哪些类别比较多。但其实也可以通过一些数据结构来辅助,比如最大堆,这里就不多做介绍,有兴趣可以百度最大堆相关数据结构的知识。

2.2 K值选择
通过上面那张图我们知道K的取值比较重要,那么该如何确定K取多少值好呢?答案是通过交叉验证(将样本数据按照一定比例,拆分出训练用的数据和验证用的数据,比如6:4拆分出部分训练数据和验证数据),从选取一个较小的K值开始,不断增加K的值,然后计算验证集合的方差,最终找到一个比较合适的K值。

通过交叉验证计算方差后你大致会得到下面这样的图:

在这里插入图片描述

这个图其实很好理解,当你增大k的时候,一般错误率会先降低,因为有周围更多的样本可以借鉴了,分类效果会变好。但注意,和K-means不一样,当K值更大的时候,错误率会更高。这也很好理解,比如说你一共就35个样本,当你K增大到30的时候,KNN基本上就没意义了。

所以选择K点的时候可以选择一个较大的临界K点,当它继续增大或减小的时候,错误率都会上升,比如图中的K=10。
三.kNN 的特点:
1.简单. 没有学习过程, 也被称为惰性学习 lazy learning. 类似于开卷考试, 在已有数据中去找答案.
2.本源. 找相似, 正是人类认识事物的常用方法, 隐藏于人类或者其他动物的基因里面. 当然, 人类也会上当, 例如有人把邻居的滴水观音误认为是芋头, 偷食后中毒.
3.效果好. 永远不要小视 kNN, 对于很多数据, 你很难设计算法超越它.
4.适应性强. 可用于分类, 回归. 可用于各种数据.
5.可扩展性强. 设计不同的度量, 可获得意想不到的效果.
6.一般需要对数据归一化.
7.复杂度高. 这也是 kNN 最重要的缺点. 对于每一个测试数据, 复杂度为 O((m+k)n), 其中n 为训练数据个数, m 为条件属性个数, k 为邻居个数.

ps:简单得说,当需要使用分类算法,且数据比较大的时候就可以尝试使用KNN算法进行分类了。

代码:

package machinelearning.knn;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;

import weka.core.*;

/**
 * kNN classification.
 * 
 * @author Fan Min minfanphd@163.com.
 */
public class KnnClassification {

	/**
	 * Manhattan distance.
	 */
	public static final int MANHATTAN = 0;

	/**
	 * Euclidean distance.
	 */
	public static final int EUCLIDEAN = 1;

	/**
	 * The distance measure.
	 */
	public int distanceMeasure = EUCLIDEAN;

	/**
	 * A random instance;
	 */
	public static final Random random = new Random();

	/**
	 * The number of neighbors.
	 */
	int numNeighbors = 7;

	/**
	 * The whole dataset.
	 */
	Instances dataset;

	/**
	 * The training set. Represented by the indices of the data.
	 */
	int[] trainingSet;

	/**
	 * The testing set. Represented by the indices of the data.
	 */
	int[] testingSet;

	/**
	 * The predictions.
	 */
	int[] predictions;

	/**
	 *********************
	 * The first constructor.
	 * 
	 * @param paraFilename
	 *            The arff filename.
	 *********************
	 */
	public KnnClassification(String paraFilename) {
		try {
			FileReader fileReader = new FileReader(paraFilename);
			dataset = new Instances(fileReader);
			// The last attribute is the decision class.
			dataset.setClassIndex(dataset.numAttributes() - 1);
			fileReader.close();
		} catch (Exception ee) {
			System.out.println("Error occurred while trying to read \'" + paraFilename
					+ "\' in KnnClassification constructor.\r\n" + ee);
			System.exit(0);
		} // Of try
	}// Of the first constructor

	/**
	 *********************
	 * Get a random indices for data randomization.
	 * 
	 * @param paraLength
	 *            The length of the sequence.
	 * @return An array of indices, e.g., {4, 3, 1, 5, 0, 2} with length 6.
	 *********************
	 */
	public static int[] getRandomIndices(int paraLength) {
		int[] resultIndices = new int[paraLength];

		// Step 1. Initialize.
		for (int i = 0; i < paraLength; i++) {
			resultIndices[i] = i;
		} // Of for i

		// Step 2. Randomly swap.
		int tempFirst, tempSecond, tempValue;
		for (int i = 0; i < paraLength; i++) {
			// Generate two random indices.
			tempFirst = random.nextInt(paraLength);
			tempSecond = random.nextInt(paraLength);

			// Swap.
			tempValue = resultIndices[tempFirst];
			resultIndices[tempFirst] = resultIndices[tempSecond];
			resultIndices[tempSecond] = tempValue;
		} // Of for i

		return resultIndices;
	}// Of getRandomIndices

	/**
	 *********************
	 * Split the data into training and testing parts.
	 * 
	 * @param paraTrainingFraction
	 *            The fraction of the training set.
	 *********************
	 */
	public void splitTrainingTesting(double paraTrainingFraction) {
		int tempSize = dataset.numInstances();
		int[] tempIndices = getRandomIndices(tempSize);
		int tempTrainingSize = (int) (tempSize * paraTrainingFraction);

		trainingSet = new int[tempTrainingSize];
		testingSet = new int[tempSize - tempTrainingSize];

		for (int i = 0; i < tempTrainingSize; i++) {
			trainingSet[i] = tempIndices[i];
		} // Of for i

		for (int i = 0; i < tempSize - tempTrainingSize; i++) {
			testingSet[i] = tempIndices[tempTrainingSize + i];
		} // Of for i
	}// Of splitTrainingTesting

	/**
	 *********************
	 * Predict for the whole testing set. The results are stored in predictions.
	 * #see predictions.
	 *********************
	 */
	public void predict() {
		predictions = new int[testingSet.length];
		for (int i = 0; i < predictions.length; i++) {
			predictions[i] = predict(testingSet[i]);
		} // Of for i
	}// Of predict

	/**
	 *********************
	 * Predict for given instance.
	 * 
	 * @return The prediction.
	 *********************
	 */
	public int predict(int paraIndex) {
		int[] tempNeighbors = computeNearests(paraIndex);
		int resultPrediction = simpleVoting(tempNeighbors);

		return resultPrediction;
	}// Of predict

	/**
	 *********************
	 * The distance between two instances.
	 * 
	 * @param paraI
	 *            The index of the first instance.
	 * @param paraJ
	 *            The index of the second instance.
	 * @return The distance.
	 *********************
	 */
	public double distance(int paraI, int paraJ) {
		double resultDistance = 0;
		double tempDifference;
		switch (distanceMeasure) {
		case MANHATTAN:
			for (int i = 0; i < dataset.numAttributes() - 1; i++) {
				tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
				if (tempDifference < 0) {
					resultDistance -= tempDifference;
				} else {
					resultDistance += tempDifference;
				} // Of if
			} // Of for i
			break;

		case EUCLIDEAN:
			for (int i = 0; i < dataset.numAttributes() - 1; i++) {
				tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
				resultDistance += tempDifference * tempDifference;
			} // Of for i
			break;
		default:
			System.out.println("Unsupported distance measure: " + distanceMeasure);
		}// Of switch

		return resultDistance;
	}// Of distance

	/**
	 *********************
	 * Get the accuracy of the classifier.
	 * 
	 * @return The accuracy.
	 *********************
	 */
	public double getAccuracy() {
		// A double divides an int gets another double.
		double tempCorrect = 0;
		for (int i = 0; i < predictions.length; i++) {
			if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
				tempCorrect++;
			} // Of if
		} // Of for i

		return tempCorrect / testingSet.length;
	}// Of getAccuracy

	/**
	 ************************************
	 * Compute the nearest k neighbors. Select one neighbor in each scan. In
	 * fact we can scan only once. You may implement it by yourself.
	 * 
	 * @param paraK
	 *            the k value for kNN.
	 * @param paraCurrent
	 *            current instance. We are comparing it with all others.
	 * @return the indices of the nearest instances.
	 ************************************
	 */
	public int[] computeNearests(int paraCurrent) {
		int[] resultNearests = new int[numNeighbors];
		boolean[] tempSelected = new boolean[trainingSet.length];
		double tempMinimalDistance;
		int tempMinimalIndex = 0;

		// Compute all distances to avoid redundant computation.
		double[] tempDistances = new double[trainingSet.length];
		for (int i = 0; i < trainingSet.length; i ++) {
			tempDistances[i] = distance(paraCurrent, trainingSet[i]);
		}//Of for i
		
		// Select the nearest paraK indices.
		for (int i = 0; i < numNeighbors; i++) {
			tempMinimalDistance = Double.MAX_VALUE;

			for (int j = 0; j < trainingSet.length; j++) {
				if (tempSelected[j]) {
					continue;
				} // Of if

				if (tempDistances[j] < tempMinimalDistance) {
					tempMinimalDistance = tempDistances[j];
					tempMinimalIndex = j;
				} // Of if
			} // Of for j

			resultNearests[i] = trainingSet[tempMinimalIndex];
			tempSelected[tempMinimalIndex] = true;
		} // Of for i

		System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
		return resultNearests;
	}// Of computeNearests

	/**
	 ************************************
	 * Voting using the instances.
	 * 
	 * @param paraNeighbors
	 *            The indices of the neighbors.
	 * @return The predicted label.
	 ************************************
	 */
	public int simpleVoting(int[] paraNeighbors) {
		int[] tempVotes = new int[dataset.numClasses()];
		for (int i = 0; i < paraNeighbors.length; i++) {
			tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
		} // Of for i

		int tempMaximalVotingIndex = 0;
		int tempMaximalVoting = 0;
		for (int i = 0; i < dataset.numClasses(); i++) {
			if (tempVotes[i] > tempMaximalVoting) {
				tempMaximalVoting = tempVotes[i];
				tempMaximalVotingIndex = i;
			} // Of if
		} // Of for i

		return tempMaximalVotingIndex;
	}// Of simpleVoting

	/**
	 *********************
	 * The entrance of the program.
	 * 
	 * @param args
	 *            Not used now.
	 *********************
	 */
	public static void main(String args[]) {
		KnnClassification tempClassifier = new KnnClassification("D:/data/iris.arff");
		tempClassifier.splitTrainingTesting(0.8);
		tempClassifier.predict();
		System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
	}// Of main

}// Of class KnnClassification
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值