算法介绍
- 在特征空间中统计k个距离最近的样本的标签,选择最多的标签最为自己的标签。
- 可以采用多种距离计算策略,如曼哈顿距离、欧氏距离。
- KNN是非参且惰性的。
- 优点:实现简单、训练快(惰性)、效果好、对异常值不敏感
- 缺点:时空复杂度都高、需要合适的归一化等
算法流程
变量准备
在这里准备了两种距离策略、训练集、测试集、验证集和k值。
//两种距离
public static final int MANHATTAN = 0;
public static final int EUCLIDEAN = 1;
//距离策略
public int distanceMeasure = EUCLIDEAN;
//设置随机种子
public static final Random random = new Random();
//设置k值
int numNeighbors = 7;
//数据集
Instances dataset;
//训练集、测试集、验证集数组
int[] trainingSet;
int[] testingSet;
int[] predictions;
用Instances类读取数据集
使用构造方法,读取数据集并存入Instances类中。
/**
* 构造方法,用Instances类读取数据集
* @param paraFilename 数据集地址
*/
public Knn(String paraFilename) {
try {
FileReader fileReader = new FileReader(paraFilename);
//这里用Instances读取数据集
dataset = new Instances(fileReader);
//将当前Istances类的标签下标设置为(数据集的属性数-1)?
dataset.setClassIndex(dataset.numAttributes() - 1);
fileReader.close();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
}
打乱数据集并划分训练集与测试集
在这里将原本的数据集随机交换打乱,再按照比例将数据集划分为训练集与测试集。
/**
* 获得一个随机序列
* @param paraLength 序列长度
* @return 打乱后的序列
*/
public static int[] getRandomIndices(int paraLength) {
int[] resultIndices = new int[paraLength];
//按序号赋值
for(int i = 0; i < paraLength; i++) {
resultIndices[i] = i;
}
//随机交换打乱
int tempFirst, tempSecond, tempValue;
for(int i = 0; i < paraLength; i++) {
tempFirst = random.nextInt(paraLength);
tempSecond = random.nextInt(paraLength);
tempValue = resultIndices[tempFirst];
resultIndices[tempFirst] = resultIndices[tempSecond];
resultIndices[tempFirst] = tempValue;
}
return resultIndices;
}
/**
* 划分训练集与验证集,输入训练集的占比,将数据集的下标放入trainingSet和testingSet中
* @param paraTrainingFraction 训练集比例
*/
public void splitTrainingTesting(double paraTrainingFraction) {
//获得数据集记录数
int tempSize = dataset.numInstances();
//得到长度为tempSize的随机序列
int[] tempIndices = getRandomIndices(tempSize);
//获得训练集记录数
int tempTrainingSize = (int)(tempSize * paraTrainingFraction);
//声明训练集与测试集
trainingSet = new int[tempTrainingSize];
testingSet = new int[tempSize - tempTrainingSize];
//为训练集与测试集赋值
for (int i = 0; i < tempTrainingSize; i++) {
trainingSet[i] = tempIndices[i];
}
for (int i = 0; i < tempSize - tempTrainingSize; i++) {
testingSet[i] = tempIndices[tempTrainingSize + i];
}
}
距离计算
这里规定了两种距离,并实现了两种距离的计算。
/**
* 计算两条记录之间对应的距离
* @param paraI 一个记录下标
* @param paraJ 另一条记录下标
* @return 距离结果
*/
public double distance(int paraI, int paraJ) {
double resultDistance = 0;
double tempDifference;
//两种距离策略分开计算
switch (distanceMeasure) {
case MANHATTAN:
//对除去标签以外的所有属性进行处理
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
//归一化处理?
tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
if (tempDifference < 0) {
resultDistance -= tempDifference;
} else {
resultDistance += tempDifference;
}
}
break;
case EUCLIDEAN:
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
resultDistance += tempDifference * tempDifference;
}
break;
default:
System.out.println("Unsupported distance measure: " + distanceMeasure);
}
return resultDistance;
}
得到k个最近记录
该部分是输入当前记录的下标,计算距离最近的k个记录,并返回。
/**
* 得到距离最近的k个记录,存放在数组中
* @param paraCurrent 当前记录的下标
* @return 返回k个最近记录
*/
public int[] computeNearests(int paraCurrent) {
//声明空间为k的数组
int[] resultNearests = new int[numNeighbors];
//声明长度同训练集的数组,作用是判断对应训练集记录是否已经被加入resultNearests数组
boolean[] tempSelected = new boolean[trainingSet.length];
double tempMinimalDistance;
int tempMinimalIndex = 0;
//
double[] tempDistances = new double[trainingSet.length];
for (int i = 0; i < trainingSet.length; i ++) {
tempDistances[i] = distance(paraCurrent, trainingSet[i]);
}
for (int i = 0; i < numNeighbors; i++) {
tempMinimalDistance = Double.MAX_VALUE;
for (int j = 0; j < trainingSet.length; j++) {
if (tempSelected[j]) {
continue;
}
if (tempDistances[j] < tempMinimalDistance) {
tempMinimalDistance = tempDistances[j];
tempMinimalIndex = j;
}
}
resultNearests[i] = trainingSet[tempMinimalIndex];
tempSelected[tempMinimalIndex] = true;
}
System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
return resultNearests;
}
对单条记录进行预测
该部分是输入一条记录,利用得到的k个最近记录,实现预测功能。
/**
* 对单条记录进行预测
* @param paraNeighbors k个最近记录
* @return 返回预测得到的标签
*/
public int simpleVoting(int[] paraNeighbors) {
//以数据集的标签个数建立数组,目的是统计对于当前记录而言,最接近哪个标签
int[] tempVotes = new int[dataset.numClasses()];
//循环统计最近的k个记录,将k个记录的所属标签计入数组
for (int i = 0; i < paraNeighbors.length; i++) {
tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
}
//选择最多的标签作为自己的标签
int tempMaximalVotingIndex = 0;
int tempMaximalVoting = 0;
for (int i = 0; i < dataset.numClasses(); i++) {
if (tempVotes[i] > tempMaximalVoting) {
tempMaximalVoting = tempVotes[i];
tempMaximalVotingIndex = i;
}
}
//返回所属标签的下标
return tempMaximalVotingIndex;
}
/**
* 预测单条记录的标签
* @param paraIndex 单条记录的下标
* @return 返回标签值的序号
*/
public int predict(int paraIndex) {
//得到距离最近的k个记录
int[] tempNeighbors = computeNearests(paraIndex);
//投票得到k个记录中最多的标签值
int resultPrediction = simpleVoting(tempNeighbors);
//返回预测标签的序号
return resultPrediction;
}
对测试集进行预测存入验证集
利用单条记录预测功能实现全部测试集的预测功能。
/**
* 将测试集的预测结果放在验证集中
*/
public void predict() {
//为验证集赋予空间
predictions = new int[testingSet.length];
//将测试集的预测结果放在验证集中
for (int i = 0; i < predictions.length; i++) {
predictions[i] = predict(testingSet[i]);
}
}
计算准确率
根据验证集的结果与真实结果进行比对,计算模型预测准确率。
/**
* 根据验证集的结果与真实结果进行比对,计算模型预测准确率
* @return 预测准确率
*/
public double getAccuracy() {
double tempCorrect = 0;
for (int i = 0; i < predictions.length; i++) {
if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
tempCorrect++;
}
}
return tempCorrect / testingSet.length;
}
详细注释
package knn_nb;
import weka.core.Instances;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
public class Knn {
//两种距离
public static final int MANHATTAN = 0;
public static final int EUCLIDEAN = 1;
//距离策略
public int distanceMeasure = EUCLIDEAN;
//设置随机种子
public static final Random random = new Random();
//设置k值
int numNeighbors = 7;
//数据集
Instances dataset;
//训练集、测试集、验证集数组
int[] trainingSet;
int[] testingSet;
int[] predictions;
/**
* 构造方法,用Instances类读取数据集
* @param paraFilename 数据集地址
*/
public Knn(String paraFilename) {
try {
FileReader fileReader = new FileReader(paraFilename);
//这里用Instances读取数据集
dataset = new Instances(fileReader);
//将当前Istances类的标签下标设置为(数据集的属性数-1)?
dataset.setClassIndex(dataset.numAttributes() - 1);
fileReader.close();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* 获得一个随机序列
* @param paraLength 序列长度
* @return 打乱后的序列
*/
public static int[] getRandomIndices(int paraLength) {
int[] resultIndices = new int[paraLength];
//按序号赋值
for(int i = 0; i < paraLength; i++) {
resultIndices[i] = i;
}
//随机交换打乱
int tempFirst, tempSecond, tempValue;
for(int i = 0; i < paraLength; i++) {
tempFirst = random.nextInt(paraLength);
tempSecond = random.nextInt(paraLength);
tempValue = resultIndices[tempFirst];
resultIndices[tempFirst] = resultIndices[tempSecond];
resultIndices[tempFirst] = tempValue;
}
return resultIndices;
}
/**
* 划分训练集与验证集,输入训练集的占比,将数据集的下标放入trainingSet和testingSet中
* @param paraTrainingFraction 训练集比例
*
*/
public void splitTrainingTesting(double paraTrainingFraction) {
//获得数据集记录数
int tempSize = dataset.numInstances();
//得到长度为tempSize的随机序列
int[] tempIndices = getRandomIndices(tempSize);
//获得训练集记录数
int tempTrainingSize = (int)(tempSize * paraTrainingFraction);
//声明训练集与测试集
trainingSet = new int[tempTrainingSize];
testingSet = new int[tempSize - tempTrainingSize];
//为训练集与测试集赋值
for (int i = 0; i < tempTrainingSize; i++) {
trainingSet[i] = tempIndices[i];
}
for (int i = 0; i < tempSize - tempTrainingSize; i++) {
testingSet[i] = tempIndices[tempTrainingSize + i];
}
}
/**
* 将测试集的预测结果放在验证集中
*/
public void predict() {
//为验证集赋予空间
predictions = new int[testingSet.length];
//将测试集的预测结果放在验证集中
for (int i = 0; i < predictions.length; i++) {
predictions[i] = predict(testingSet[i]);
}
}
/**
* 预测单条记录的标签
* @param paraIndex 单条记录的下标
* @return 返回标签值的序号
*/
public int predict(int paraIndex) {
//得到距离最近的k个记录
int[] tempNeighbors = computeNearests(paraIndex);
//投票得到k个记录中最多的标签值
int resultPrediction = simpleVoting(tempNeighbors);
//返回预测标签的序号
return resultPrediction;
}
/**
* 计算两条记录之间对应的距离
* @param paraI 一个记录下标
* @param paraJ 另一条记录下标
* @return 距离结果
*/
public double distance(int paraI, int paraJ) {
double resultDistance = 0;
double tempDifference;
//两种距离策略分开计算
switch (distanceMeasure) {
case MANHATTAN:
//对除去标签以外的所有属性进行处理
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
//归一化处理?
tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
if (tempDifference < 0) {
resultDistance -= tempDifference;
} else {
resultDistance += tempDifference;
}
}
break;
case EUCLIDEAN:
for (int i = 0; i < dataset.numAttributes() - 1; i++) {
tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
resultDistance += tempDifference * tempDifference;
}
break;
default:
System.out.println("Unsupported distance measure: " + distanceMeasure);
}
return resultDistance;
}
/**
* 根据验证集的结果与真实结果进行比对,计算模型预测准确率
* @return 预测准确率
*/
public double getAccuracy() {
double tempCorrect = 0;
for (int i = 0; i < predictions.length; i++) {
if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
tempCorrect++;
}
}
return tempCorrect / testingSet.length;
}
/**
* 得到距离最近的k个记录,存放在数组中
* @param paraCurrent 当前记录的下标
* @return 返回k个最近记录
*/
public int[] computeNearests(int paraCurrent) {
//声明空间为k的数组
int[] resultNearests = new int[numNeighbors];
//声明长度同训练集的数组,作用是判断对应训练集记录是否已经被加入resultNearests数组
boolean[] tempSelected = new boolean[trainingSet.length];
double tempMinimalDistance;
int tempMinimalIndex = 0;
//
double[] tempDistances = new double[trainingSet.length];
for (int i = 0; i < trainingSet.length; i ++) {
tempDistances[i] = distance(paraCurrent, trainingSet[i]);
}
for (int i = 0; i < numNeighbors; i++) {
tempMinimalDistance = Double.MAX_VALUE;
for (int j = 0; j < trainingSet.length; j++) {
if (tempSelected[j]) {
continue;
}
if (tempDistances[j] < tempMinimalDistance) {
tempMinimalDistance = tempDistances[j];
tempMinimalIndex = j;
}
}
resultNearests[i] = trainingSet[tempMinimalIndex];
tempSelected[tempMinimalIndex] = true;
}
System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
return resultNearests;
}
/**
* 对单条记录进行预测
* @param paraNeighbors k个最近记录
* @return 返回预测得到的标签
*/
public int simpleVoting(int[] paraNeighbors) {
//以数据集的标签个数建立数组,目的是统计对于当前记录而言,最接近哪个标签
int[] tempVotes = new int[dataset.numClasses()];
//循环统计最近的k个记录,将k个记录的所属标签计入数组
for (int i = 0; i < paraNeighbors.length; i++) {
tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
}
//选择最多的标签作为自己的标签
int tempMaximalVotingIndex = 0;
int tempMaximalVoting = 0;
for (int i = 0; i < dataset.numClasses(); i++) {
if (tempVotes[i] > tempMaximalVoting) {
tempMaximalVoting = tempVotes[i];
tempMaximalVotingIndex = i;
}
}
//返回所属标签的下标
return tempMaximalVotingIndex;
}
/**
* 主函数,启动预测并展示准确率
* @param args
*/
public static void main(String args[]) {
Knn knn = new Knn("C:\\Users\\hp\\Desktop\\deepLearning\\src\\main\\java\\knn_nb\\iris.arff");
knn.splitTrainingTesting(0.8);
knn.predict();
System.out.println("The accuracy of the classifier is: " + knn.getAccuracy());
}
}