简介:
- 全名K-nearest neighber 算法,它是一个根据训练集并且非规则的分类算法。
- Cover 和 Hart 在 1968年提出的最初的邻近算法
- 通过测量不同特征值之间的距离方法进行分类
- 属于懒惰学习方法Lazy learner
- Eager Learners:接收测试数据之前,对已有的数据构建一个分类模型。
- Lazy Learners (instance-based learners) : 在训练阶段仅仅把样本保存下来,训练时间开销为零,待接收到测试样本后在进行处理。
算法原理:
存在一个样本集合(训练样本集),并且样本中的每个数据都有标签(也是类别)。
输入没有类别标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后提取样本集最相似的K个分类标签。
选择这个K个标签中次数最多的分类为新数据的分类。
举例:
五列数据分别对应花萼长度、花萼宽度、花瓣长度、花瓣宽度和种类,其中种类分别为山鸢尾、变色鸢尾和维吉尼亚鸢尾三个类别。
5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica
我在这里只选用了每种花的四个数据样本,实际上数据样本可以非常的大。为了方便起见我把它们画在一个平面上(距离可能不准,请理解思路)。
然后有一个测试数据进来:5.0,3.0,1.6,0.2 不知道它是那种花,因此我们给定一个k值等于三,找出他周围最相邻的三个数据。图中黑球像这样:
我们直观的看出当 K = 3 时,其中有两个红球,一个蓝球,也就是说我们的待测数据跟红球样本相似度更大,所以判定此花是Iris-setosa。
问题来了:K值是我们给定的,哪我们怎样去判断距离呢?
第一步:明确距离坐标的含义,就是事物的特征值。比如此处的花萼长度、花萼宽度、花瓣长度、花瓣宽度就是坐标中的 x,y,z. 多了画不出来没关系,下面我们将介绍它的距离计算。
第二步:距离计算,我介绍常用的两种。
欧式距离:大家再熟悉熟悉不过了,例如:x=(x1,x2,...,xn) , y = (y1,y2,...,yn)
曼哈顿距离:x = (x1,x2) , y = (y1,y2) d = | x1 - y1 | + | x2 - y2 |
图中橘色的线段便是欧氏距离,而紫色的便是曼哈顿距离。两种距离各有千秋,像我们这道题:给定花的特征判断花的种类,欧式距离比较好,如果是城乡规划路径上的问题曼哈顿距离优势更准确,因为你不能直接越过楼房算位移,我们需要路程。
投票排序:
我们找出了 3 个距待测样本最近的数据后就要进行投票了。比如 第一个球是红色,红球+1,第二球是蓝色,蓝球+1,第三个球是红色红球+1。最终的结果是 :红(2票),蓝(1票)。代码中体现的是花的种类对应的数组下标。
根据票数我们就知道测试的结果是红球。可这样还不行,这只是预测,我们还需要来估算预测的准确率是多少,它与我们的K选取有关。比如:100个未知数据,我们预测的结果和真实的结果相比较,发现对了99个,那么预测的准确率便是 0.99.
代码:
package machinelearning.knn;
import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
import weka.core.*;
public class KnnClassification {
/**
* 曼哈顿距离
*/
public static final int MANHATTAN = 0;
/**
* 欧式距离
*/
public static final int EUCLIDEAN = 1;
/**
* 距离测量
*/
public int distanceMeasure = EUCLIDEAN; //现将距离测量置为1
/**
* 生成随机距离
*/
public static final Random random = new Random();
/**
* 邻居数目
*/
int numNeighbors = 7;
/**
* The whole dataset.
*/
Instances dataset;
/**
* 训练场地,由数据索引表示
*/
int[] trainingSet;
/**
* 测试集,由数据索引表示
*/
int[] testingSet;
/**
* 预测
*/
int[] predictions;
// 导入数据
public KnnClassification(String paraFilename) {
try {
FileReader fileReader = new FileReader(paraFilename);
dataset = new Instances(fileReader);
// The last attribute is the decision class.
dataset.setClassIndex(dataset.numAttributes() - 1);
fileReader.close();
} catch (Exception ee) {
System.out.println("Error occurred while trying to read \'" + paraFilename
+ "\' in KnnClassification constructor.\r\n" + ee);
System.exit(0);
} // Of try
}// Of the first constructor
// 获得随机索引,数据随机化
public static int[] getRandomIndices(int paraLength) {
int[] resultIndices = new int[paraLength];
// Step 1. 初始化resultIndices数组
for (int i = 0; i < paraLength; i++) {
resultIndices[i] = i;
} // Of for i
// Step 2. 随机交换
int tempFirst, tempSecond, tempValue;
// for 循环的目的是打乱数组中数字的顺序
for (int i = 0; i < paraLength; i++) {
// 生成两个随机索引,tempFirst,tempSecond
tempFirst = random.nextInt(paraLength);
tempSecond = random.nextInt(paraLength);
// Swap. 交换
tempValue = resultIndices[tempFirst];
resultIndices[tempFirst] = resultIndices[tempSecond];
resultIndices[tempSecond] = tempValue;
} // Of for i
return resultIndices;
}// Of getRandomIndices
// 把导入的数据一分为二,训练场地和测试场地
public void splitTrainingTesting(double paraTrainingFraction) {
//
int tempSize = dataset.numInstances();
//得到一个测试数组
int[] tempIndices = getRandomIndices(tempSize);
int tempTrainingSize = (int) (tempSize * paraTrainingFraction);
trainingSet = new int[tempTrainingSize];
testingSet = new int[tempSize - tempTrainingSize];
for (int i = 0; i < tempTrainingSize; i++) {
trainingSet[i] = tempIndices[i];
} // Of for i
for (int i = 0; i < tempSize - tempTrainingSize; i++) {
testingSet[i] = tempIndices[tempTrainingSize + i];
} // Of for i
}// Of splitTrainingTesting
// 把预测的结果放在predictions数组中集中起来,后面做准确率计算
public void predict() {
predictions = new int[testingSet.length];
for (int i = 0; i < predictions.length; i++) {
predictions[i] = predict(testingSet[i]);
} // Of for i
}// Of predict
// 进行预测
public int predict(int paraIndex) {
// 1. 找出相邻的 K 个邻居
int[] tempNeighbors = computeNearests(paraIndex);
// 2. 进行投票找出种类标签(数组下标表示,所以返回的是一个int型)
int resultPrediction = simpleVoting(tempNeighbors);
return resultPrediction;
}// Of predict
// 距离测量,两种选择,曼哈顿,欧式距离
public double distance(int paraI, int paraJ) { // 猜测I和J是两个
double resultDistance = 0;
double tempDifference;
switch (distanceMeasure) {
case MANHATTAN:
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
if (tempDifference < 0) {
resultDistance -= tempDifference;
} else {
resultDistance += tempDifference;
} // Of if
} // Of for i
break;
case EUCLIDEAN:
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
resultDistance += tempDifference * tempDifference;
} // Of for i
break;
default:
System.out.println("Unsupported distance measure: " + distanceMeasure);
}// Of switch
return resultDistance;
}// Of distance
// 精确度
public double getAccuracy() {
double tempCorrect = 0;
// 把正确的预测统计出来
for (int i = 0; i < predictions.length; i++) {
if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
tempCorrect++;
} // Of if
} // Of for i
return tempCorrect / testingSet.length;
}// Of getAccuracy
// 获得最近的 k 个邻居放入数组中
public int[] computeNearests(int paraCurrent) {
int[] resultNearests = new int[numNeighbors];
boolean[] tempSelected = new boolean[trainingSet.length];
double tempDistance;
double tempMinimalDistance;
int tempMinimalIndex = 0;
// 借用循环找出距离最短的邻居下标
for (int i = 0; i < numNeighbors; i++) {
// 需要找的第 i 个邻居(共 K 个)
tempMinimalDistance = Double.MAX_VALUE;
// 从训练场地比较筛选
for (int j = 0; j < trainingSet.length; j++) {
// 如果这个结点被确定了最近结点中的一个,索引对应值true,跳过比较
if (tempSelected[j]) {
continue;
} // Of if
tempDistance = distance(paraCurrent, trainingSet[j]);
if (tempDistance < tempMinimalDistance) {
tempMinimalDistance = tempDistance;
tempMinimalIndex = j;
} // Of if
} // Of for j
resultNearests[i] = trainingSet[tempMinimalIndex];
tempSelected[tempMinimalIndex] = true;
} // Of for i
System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
return resultNearests;
}// Of computeNearests
// 进行投票比较
public int simpleVoting(int[] paraNeighbors) {
int[] tempVotes = new int[dataset.numClasses()];
// 投票
for (int i = 0; i < paraNeighbors.length; i++) {
tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
} // Of for i
int tempMaximalVotingIndex = 0;
int tempMaximalVoting = 0;
// 比较找出票数最多的
for (int i = 0; i < dataset.numClasses(); i++) {
if (tempVotes[i] > tempMaximalVoting) {
tempMaximalVoting = tempVotes[i];
tempMaximalVotingIndex = i;
} // Of if
} // Of for i
return tempMaximalVotingIndex;
}// Of simpleVoting
// 测试程序入口
public static void main(String args[]) {
KnnClassification tempClassifier = new KnnClassification("D:/data/iris.arff");
tempClassifier.splitTrainingTesting(0.8);
tempClassifier.predict();
System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
}// Of main
}// Of class KnnClassification
给出了详细的代码,其中的要点已经最大化的详细了。