这段代码约 300 行,分三天完成。今天先把代码抄完并运行,明后天还有修改程序的工作,要求熟练掌握。
1.两种距离度量.
2.数据随机分割方式.
3.间址的灵活使用: trainingSet 和 testingSet 都是整数数组, 表示下标.
4.arff 文件的读取. 需要 weka.jar 包.
5.求邻居.
下载并安装 weka.jar 包,加载 arff 文件,代码如下:
package machinelearning.knn;
import java.util.Arrays;
import java.io.FileReader;
import java.util.Random;
import weka.core.*;
/**
 * A simple kNN (k-nearest-neighbors) classifier backed by a Weka dataset.
 *
 * <p>Indirection is used throughout: {@code trainingSet} and
 * {@code testingSet} are arrays of row indices into {@code dataset},
 * so no instance data is ever copied.
 *
 * @author Qing Zhang
 * @since 2021/7/1
 */
public class KnnClassification {

    /** Manhattan distance: sum of |x_i - y_i| over the attributes. */
    public static final int MANHATTAN = 0;

    /** Euclidean distance (the square root is omitted; see {@link #distance}). */
    public static final int EUCLIDEAN = 1;

    /** Which distance measure to apply; defaults to Euclidean. */
    public int distanceMeasure = EUCLIDEAN;

    /** Shared random number generator used for shuffling the data. */
    public static final Random random = new Random();

    /** Number of neighbors consulted when voting (the "k" in kNN). */
    int numNeighbors = 7;

    /** The whole dataset, loaded from an ARFF file. */
    Instances dataset;

    /** Training set, stored as indices into {@code dataset}. */
    int[] trainingSet;

    /** Testing set, stored as indices into {@code dataset}. */
    int[] testingSet;

    /** Predicted class index for each element of {@code testingSet}. */
    int[] predictions;

    /**
     * Loads the dataset from an ARFF file and marks the last attribute as
     * the class label. Terminates the program if the file cannot be read.
     *
     * @param paraFileName path of the ARFF file to load (requires weka.jar)
     */
    public KnnClassification(String paraFileName) {
        try {
            FileReader fileReader = new FileReader(paraFileName);
            dataset = new Instances(fileReader);
            // The last attribute is the class label.
            dataset.setClassIndex(dataset.numAttributes() - 1);
            fileReader.close();
        } catch (Exception e) {
            System.out.println("Error occurred while trying to read \'" + paraFileName
                    + "\' in KnnClassification constructor.\r\n" + e);
            // Non-zero exit status: loading the data failed.
            System.exit(1);
        }
    }

    /**
     * Generates a uniformly random permutation of {@code 0 .. paraLength - 1}
     * using the Fisher-Yates shuffle. (The previous pairwise-random-swap
     * scheme does not produce all permutations with equal probability.)
     *
     * @param paraLength length of the permutation
     * @return e.g., {4, 3, 1, 5, 0, 2} for length 6
     */
    public static int[] getRandomIndices(int paraLength) {
        int[] resultIndices = new int[paraLength];
        // Step 1: initialize with the identity permutation.
        for (int i = 0; i < paraLength; i++) {
            resultIndices[i] = i;
        }
        // Step 2: Fisher-Yates — swap each slot with a random slot at or
        // before it, walking from the end toward the front.
        for (int i = paraLength - 1; i > 0; i--) {
            int tempPosition = random.nextInt(i + 1);
            int tempValue = resultIndices[i];
            resultIndices[i] = resultIndices[tempPosition];
            resultIndices[tempPosition] = tempValue;
        }
        return resultIndices;
    }

    /**
     * Randomly splits the dataset into a training set and a testing set.
     *
     * @param paraTrainingFraction fraction of instances used for training
     */
    public void splitTrainingTesting(double paraTrainingFraction) {
        // numInstances() is the number of rows in the dataset.
        int tempSize = dataset.numInstances();
        int[] tempIndices = getRandomIndices(tempSize);
        int tempTrainingSize = (int) (tempSize * paraTrainingFraction);

        trainingSet = new int[tempTrainingSize];
        testingSet = new int[tempSize - tempTrainingSize];

        // The first tempTrainingSize shuffled indices form the training set.
        for (int i = 0; i < tempTrainingSize; i++) {
            trainingSet[i] = tempIndices[i];
        }
        // The remaining indices form the testing set.
        for (int i = 0; i < tempSize - tempTrainingSize; i++) {
            testingSet[i] = tempIndices[tempTrainingSize + i];
        }
    }

    /**
     * Predicts the whole testing set; results are stored in {@code predictions}.
     */
    public void predict() {
        predictions = new int[testingSet.length];
        for (int i = 0; i < predictions.length; i++) {
            predictions[i] = predict(testingSet[i]);
        }
    }

    /**
     * Predicts the class of a single instance.
     *
     * @param paraIndex index of the instance in {@code dataset}
     * @return the predicted class index
     */
    public int predict(int paraIndex) {
        int[] tempNeighbors = computeNearests(paraIndex);
        return simpleVoting(tempNeighbors);
    }

    /**
     * Computes the distance between two instances, ignoring the class
     * attribute. For EUCLIDEAN the square root is omitted: only the
     * ordering of distances matters for neighbor selection.
     *
     * @param paraI index of the first instance
     * @param paraJ index of the second instance
     * @return the (squared, for Euclidean) distance
     */
    public double distance(int paraI, int paraJ) {
        // Bug fix: the accumulator was declared int, silently truncating
        // fractional attribute differences under the Manhattan measure.
        double resultDistance = 0;
        double tempDifference;
        switch (distanceMeasure) {
            case MANHATTAN:
                // Skip the last attribute: it is the class label.
                for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                    tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
                    resultDistance += Math.abs(tempDifference);
                }
                break;
            case EUCLIDEAN:
                for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                    tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
                    resultDistance += tempDifference * tempDifference;
                }
                break;
            default:
                System.out.println("Unsupported distance measure: " + distanceMeasure);
        }
        return resultDistance;
    }

    /**
     * Computes the accuracy of the classifier on the testing set.
     * Must be called after {@link #predict()}.
     *
     * @return fraction of correctly classified testing instances
     */
    public double getAccuracy() {
        // A double divided by an int yields a double.
        double tempCorrect = 0;
        for (int i = 0; i < predictions.length; i++) {
            // classValue() is the numeric index of the instance's class.
            if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
                tempCorrect++;
            }
        }
        return tempCorrect / testingSet.length;
    }

    /**
     * Computes the k nearest training neighbors of a given instance,
     * selecting one neighbor per scan.
     *
     * @param paraCurrent index of the query instance in {@code dataset}
     * @return dataset indices of the nearest neighbors
     */
    public int[] computeNearests(int paraCurrent) {
        // Never ask for more neighbors than there are training instances.
        int tempNumNeighbors = Math.min(numNeighbors, trainingSet.length);
        int[] resultNearests = new int[tempNumNeighbors];
        boolean[] tempSelected = new boolean[trainingSet.length];

        // Hoist the distance computation out of the selection loop: each
        // query distance is needed k times but only computed once (O(n*d)
        // instead of O(k*n*d)).
        double[] tempDistances = new double[trainingSet.length];
        for (int i = 0; i < trainingSet.length; i++) {
            tempDistances[i] = distance(paraCurrent, trainingSet[i]);
        }

        // Select the nearest unselected neighbor, k times.
        for (int i = 0; i < tempNumNeighbors; i++) {
            double tempMinimalDistance = Double.MAX_VALUE;
            int tempMinimalIndex = 0;
            for (int j = 0; j < trainingSet.length; j++) {
                if (tempSelected[j]) {
                    continue;
                }
                if (tempDistances[j] < tempMinimalDistance) {
                    tempMinimalDistance = tempDistances[j];
                    tempMinimalIndex = j;
                }
            }
            resultNearests[i] = trainingSet[tempMinimalIndex];
            tempSelected[tempMinimalIndex] = true;
        }

        System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
        return resultNearests;
    }

    /**
     * Majority vote over the neighbors' classes.
     *
     * @param paraNeighbors dataset indices of the neighbors
     * @return the class index with the most votes (ties go to the lower index)
     */
    public int simpleVoting(int[] paraNeighbors) {
        // numClasses() is the number of distinct class labels; count the
        // occurrences of each.
        int[] tempVotes = new int[dataset.numClasses()];
        for (int i = 0; i < paraNeighbors.length; i++) {
            tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
        }

        int tempMaximalVotingIndex = 0;
        int tempMaximalVoting = 0;
        for (int i = 0; i < dataset.numClasses(); i++) {
            if (tempVotes[i] > tempMaximalVoting) {
                tempMaximalVoting = tempVotes[i];
                tempMaximalVotingIndex = i;
            }
        }
        return tempMaximalVotingIndex;
    }

    /**
     * Entry point: load iris, split 80/20, predict, and report accuracy.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        KnnClassification tempClassifier = new KnnClassification("F:\\研究生\\研0\\学习\\Java_Study\\data_set\\iris.arff");
        tempClassifier.splitTrainingTesting(0.8);
        tempClassifier.predict();
        System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
    }
}
运行结果如下: