前两天刚听完闵帆老师的讲座,赶紧趁着热乎开撸。
KNN是机器学习里面的入门基础算法之一,但是它的普适性很强,对于新的问题,把KNN拿出来缝缝补补改改它又能战斗了,所以可以把它当做算法检测标杆
KNN思想(人类的比较思维):要判断一个未知的事物,可以找一个我们知道并且与之最相似的事物,我们就认为它俩是同一种事物。那么具体落到计算机上要怎么实现呢?
其最主要的就是要模拟找相似的过程:对于输入的一个向量,可以衡量它与已知数据之间的距离,距离值越小,就认为两个样本越接近。
距离的计算方式常用有两种:曼哈顿距离与欧几里得距离(这两个名字听起来特别离谱,实际上就是初中学的:前者是各维度差的绝对值之和,后者是各维度差的平方和再开方)。接着我们找到与之距离最近的K个邻居,由这K个邻居投票决定待测样本的类别。K的取值对结果的影响还是比较大的:取小了容易过拟合(结果容易被个别噪声点左右),取大了容易欠拟合(距离较远、并不相似的样本也参与了投票)。在二分类问题中,K取奇数可以避免投票出现平票的尴尬情况;多分类时即使K为奇数仍可能平票。
package com.trian;
import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
import myself.test;
import weka.core.*;
/**
 * k-nearest-neighbour (kNN) classification over a Weka {@code Instances}
 * dataset.
 *
 * <p>
 * The dataset is read from an ARFF file and its last attribute is used as the
 * class label. Instances are referred to by their index in the dataset; the
 * training set, the testing set and the neighbour lists are all arrays of such
 * indices.
 */
public class KnnClassification {
    /**
     * Manhattan distance.
     */
    public static final int MANHATTAN = 0;

    /**
     * Euclidean distance.
     */
    public static final int EUCLIDEAN = 1;

    /**
     * The distance measure in use; {@link #EUCLIDEAN} by default.
     */
    public int distanceMeasure = EUCLIDEAN;

    /**
     * A shared random instance, used to shuffle the instance indices.
     */
    public static final Random random = new Random();

    /**
     * The number of reference neighbors.
     */
    int numNeighbors = 7;

    /**
     * The dataset.
     */
    Instances dataset;

    /**
     * Dataset indices of the training instances.
     */
    int[] trainingSet;

    /**
     * Dataset indices of the testing instances.
     */
    int[] testingSet;

    /**
     * Predicted class index for each entry of {@link #testingSet}.
     */
    int[] predictions;

    /**
     * Loads the dataset from an ARFF file and marks the last attribute as the
     * class attribute.
     *
     * <p>
     * If the file cannot be read or parsed, the original message is printed,
     * the cause is reported on standard error and the JVM exits with a
     * non-zero status.
     *
     * @param paraFilename
     *            The arff filename.
     */
    public KnnClassification(String paraFilename) {
        // try-with-resources closes the reader even when parsing fails.
        try (FileReader fileReader = new FileReader(paraFilename)) {
            dataset = new Instances(fileReader);
            dataset.setClassIndex(dataset.numAttributes() - 1);
        } catch (Exception ee) {
            System.out.println("404 NOT Found");
            System.err.println("Cannot load dataset '" + paraFilename + "': " + ee);
            // A failed load is an error, so do not report success to the caller.
            System.exit(1);
        } // Of try
    }// Of the constructor

    /**
     * Builds a random permutation of {@code 0 .. paraLength - 1}.
     *
     * @param paraLength
     *            The length of the sequence.
     * @return An array holding every index in {@code [0, paraLength)} exactly
     *         once, in random order.
     */
    public static int[] getRandomIndices(int paraLength) {
        int[] resultIndices = new int[paraLength];
        for (int i = 0; i < paraLength; i++) {
            resultIndices[i] = i;
        } // Of for i

        // Fisher-Yates shuffle: every permutation is equally likely, unlike
        // swapping two arbitrary positions paraLength times.
        for (int i = paraLength - 1; i > 0; i--) {
            int tempOther = random.nextInt(i + 1);
            int tempValue = resultIndices[i];
            resultIndices[i] = resultIndices[tempOther];
            resultIndices[tempOther] = tempValue;
        } // Of for i

        return resultIndices;
    }// Of getRandomIndices

    /**
     * Randomly splits the dataset into a training set and a testing set.
     *
     * @param paraTrainingFraction
     *            The fraction of the training set, in {@code [0, 1]}.
     * @throws IllegalArgumentException
     *             if the fraction is NaN or outside {@code [0, 1]}.
     */
    public void splitTrainingTesting(double paraTrainingFraction) {
        if (!(paraTrainingFraction >= 0 && paraTrainingFraction <= 1)) {
            throw new IllegalArgumentException(
                    "Training fraction must be in [0, 1]: " + paraTrainingFraction);
        } // Of if

        int tempSize = dataset.numInstances();
        int[] tempIndices = getRandomIndices(tempSize);

        // Multiply first, then truncate. Casting the fraction alone turns
        // every value below 1 into 0 and leaves the training set empty.
        int tempTrainingSize = (int) (paraTrainingFraction * tempSize);

        trainingSet = Arrays.copyOfRange(tempIndices, 0, tempTrainingSize);
        testingSet = Arrays.copyOfRange(tempIndices, tempTrainingSize, tempSize);
    }// Of splitTrainingTesting

    /**
     * Predicts the class of every testing instance and stores the results in
     * {@link #predictions}, in the same order as {@link #testingSet}.
     */
    public void predict() {
        predictions = new int[testingSet.length];
        for (int i = 0; i < predictions.length; i++) {
            predictions[i] = predict(testingSet[i]);
        } // Of for i
    }// Of predict

    /**
     * Predicts the class of one instance by voting among its nearest training
     * neighbors.
     *
     * @param paraIndex
     *            The dataset index of the instance to classify.
     * @return The index of the predicted class.
     */
    public int predict(int paraIndex) {
        int[] tempNeighbors = computeNearests(paraIndex);
        return simpleVoting(tempNeighbors);
    }// Of predict

    /**
     * Finds the training instances closest to the given instance by repeated
     * selection of the minimum distance.
     *
     * @param paraCurrent
     *            The dataset index of the instance being classified.
     * @return The dataset indices of the nearest training instances, closest
     *         first. The array holds {@code numNeighbors} entries, or fewer
     *         when the training set is smaller than that.
     */
    public int[] computeNearests(int paraCurrent) {
        // Never ask for more neighbors than there are training instances;
        // otherwise the last selected index would be reported repeatedly.
        int tempCount = Math.min(numNeighbors, trainingSet.length);
        int[] resultNearests = new int[tempCount];
        boolean[] tempSelected = new boolean[trainingSet.length];

        // Compute every distance once instead of once per selection round.
        double[] tempDistances = new double[trainingSet.length];
        for (int i = 0; i < trainingSet.length; i++) {
            tempDistances[i] = distance(paraCurrent, trainingSet[i]);
        } // Of for i

        int tempMinimalIndex = 0;
        for (int i = 0; i < tempCount; i++) {
            double tempMinimalDistance = Double.MAX_VALUE;
            for (int j = 0; j < trainingSet.length; j++) {
                if (tempSelected[j]) {
                    continue;
                } // Of if
                if (tempDistances[j] < tempMinimalDistance) {
                    tempMinimalDistance = tempDistances[j];
                    tempMinimalIndex = j;
                } // Of if
            } // Of for j
            resultNearests[i] = trainingSet[tempMinimalIndex];
            tempSelected[tempMinimalIndex] = true;
        } // Of for i

        System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
        return resultNearests;
    }// Of computeNearests

    /**
     * Computes the distance between two instances over all attributes except
     * the last one, which holds the class label.
     *
     * <p>
     * For {@link #MANHATTAN} this is the sum of absolute differences. For
     * {@link #EUCLIDEAN} it is the sum of squared differences; the square root
     * is omitted, which does not change the ranking of neighbors.
     *
     * @param para1
     *            The dataset index of the first instance.
     * @param para2
     *            The dataset index of the second instance.
     * @return The distance under the current {@link #distanceMeasure}.
     * @throws IllegalStateException
     *             if {@link #distanceMeasure} is not a supported measure.
     */
    public double distance(int para1, int para2) {
        int tempNumInputs = dataset.numAttributes() - 1;
        double resultDistance = 0;
        double tempDifference;

        switch (distanceMeasure) {
        case MANHATTAN:
            for (int i = 0; i < tempNumInputs; i++) {
                tempDifference = dataset.instance(para1).value(i) - dataset.instance(para2).value(i);
                // Sum magnitudes; signed differences would cancel each other.
                resultDistance += Math.abs(tempDifference);
            } // Of for i
            break; // Do not fall through into the Euclidean accumulation.
        case EUCLIDEAN:
            for (int i = 0; i < tempNumInputs; i++) {
                tempDifference = dataset.instance(para1).value(i) - dataset.instance(para2).value(i);
                resultDistance += tempDifference * tempDifference;
            } // Of for i
            break;
        default:
            throw new IllegalStateException("Unsupported distance measure: " + distanceMeasure);
        }// Of switch

        return resultDistance;
    }// Of distance

    /**
     * Picks the class that occurs most often among the given neighbors. Ties
     * go to the class with the smallest index.
     *
     * @param paraNeighbors
     *            The dataset indices of the neighbors.
     * @return The index of the winning class; 0 if no neighbor is given.
     */
    public int simpleVoting(int[] paraNeighbors) {
        int[] tempVotes = new int[dataset.numClasses()];
        // Count the neighbors actually supplied, which may be fewer than
        // numNeighbors.
        for (int tempNeighbor : paraNeighbors) {
            tempVotes[(int) dataset.instance(tempNeighbor).classValue()]++;
        } // Of for tempNeighbor

        int tempMaximalVotingIndex = 0;
        int tempMaximalVoting = 0;
        for (int i = 0; i < tempVotes.length; i++) {
            if (tempVotes[i] > tempMaximalVoting) {
                tempMaximalVoting = tempVotes[i];
                tempMaximalVotingIndex = i;
            } // Of if
        } // Of for i

        return tempMaximalVotingIndex;
    }// Of simpleVoting

    /**
     * Computes the fraction of testing instances whose prediction equals the
     * actual class. {@link #predict()} must have been called first.
     *
     * @return The accuracy in {@code [0, 1]}; NaN if the testing set is empty.
     */
    public double getAccuracy() {
        double tempCorrect = 0;
        for (int i = 0; i < predictions.length; i++) {
            if (predictions[i] == (int) dataset.instance(testingSet[i]).classValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i

        return tempCorrect / testingSet.length;
    }// Of getAccuracy

    /**
     * Runs the classifier on the iris dataset and prints its accuracy.
     *
     * @param args
     *            Not used.
     */
    public static void main(String args[]) {
        // Instantiate this class rather than an unrelated one.
        KnnClassification tempClassifier = new KnnClassification(
                "C:/Users/胡来的魔术师/Desktop/sampledata-main/iris.arff");
        tempClassifier.splitTrainingTesting(0.95);
        tempClassifier.predict();
        System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
    }// Of main
}// Of class KnnClassification
运行结果:
代码分析:
dataset用于管理arff数据集。数据预处理时首先随机打乱数据,然后按照给定的比例划分训练集与测试集。computeNearests(paraIndex)方法会在训练集中选出与当前样本距离最近的k个邻居。
simpleVoting(tempNeighbors)方法会依次遍历K个邻居并记录他们的类别数量,最后返回数量最多的那个类别。
distance(int para1,int para2)计算两组数据间的距离。
总结:knn算法很容易理解,感觉也比较暴力。理清这个dataset助手后代码读写也会更容易,instance表示行,也就是一行数据,一个数据对象,Attributes表示列,也就是属性。虽然knn是入门的第一个算法,但是也不能小看它,由于它的思想过于简单没有那么多弯弯绕绕,刚接触的时候难免让人觉得离谱,但是仔细一想,我们人类认识新事物好像绝大部分时候也是这么简单。knn没有学习过程,直接在训练集里找“答案”。