Java实现KNN

该文详细介绍了如何使用Java实现KNN(K近邻)算法,包括数据集的读取、训练集与测试集的划分、距离计算(曼哈顿距离和欧氏距离)、获取k个最近邻记录、单条记录预测以及模型的预测准确率计算。代码中包含了实例数据集的读取、随机打乱数据集、预测过程以及评估预测准确性的方法。
摘要由CSDN通过智能技术生成

算法介绍

  1. 在特征空间中统计k个距离最近的样本的标签,选择最多的标签最为自己的标签。
  2. 可以采用多种距离计算策略,如曼哈顿距离、欧氏距离。
  3. KNN是非参且惰性的。
  4. 优点:实现简单、训练快(惰性)、效果好、对异常值不敏感
  5. 缺点:时空复杂度都高、需要合适的归一化等

算法流程

变量准备

在这里准备了两种距离策略、训练集、测试集、验证集和k值。

	//两种距离
    public static final int MANHATTAN = 0;
    public static final int EUCLIDEAN = 1;
    //距离策略
    public int distanceMeasure = EUCLIDEAN;
    //设置随机种子
    public static final Random random = new Random();
    //设置k值
    int numNeighbors = 7;
    //数据集
    Instances dataset;
    //训练集、测试集、验证集数组
    int[] trainingSet;
    int[] testingSet;
    int[] predictions;

用Instances类读取数据集

使用构造方法,读取数据集并存入Instances类中。

    /**
     * 构造方法,用Instances类读取数据集
     * @param paraFilename 数据集地址
     */
    public Knn(String paraFilename) {
        try {
            FileReader fileReader = new FileReader(paraFilename);
            //这里用Instances读取数据集
            dataset = new Instances(fileReader);
            //将当前Istances类的标签下标设置为(数据集的属性数-1)?
            dataset.setClassIndex(dataset.numAttributes() - 1);
            fileReader.close();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

打乱数据集并划分训练集与测试集

在这里将原本的数据集随机交换打乱,再按照比例将数据集划分为训练集与测试集。

    /**
     * 获得一个随机序列
     * @param paraLength 序列长度
     * @return 打乱后的序列
     */
    public static int[] getRandomIndices(int paraLength) {
        int[] resultIndices = new int[paraLength];
        //按序号赋值
        for(int i = 0; i < paraLength; i++) {
            resultIndices[i] = i;
        }
        //随机交换打乱
        int tempFirst, tempSecond, tempValue;
        for(int i = 0; i < paraLength; i++) {
            tempFirst = random.nextInt(paraLength);
            tempSecond = random.nextInt(paraLength);
            tempValue = resultIndices[tempFirst];
            resultIndices[tempFirst] = resultIndices[tempSecond];
            resultIndices[tempFirst] = tempValue;
        }
        return resultIndices;
    }
    /**
     * 划分训练集与验证集,输入训练集的占比,将数据集的下标放入trainingSet和testingSet中
     * @param paraTrainingFraction 训练集比例
     */
    public void splitTrainingTesting(double paraTrainingFraction) {
        //获得数据集记录数
        int tempSize = dataset.numInstances();
        //得到长度为tempSize的随机序列
        int[] tempIndices = getRandomIndices(tempSize);
        //获得训练集记录数
        int tempTrainingSize = (int)(tempSize * paraTrainingFraction);
        //声明训练集与测试集
        trainingSet = new int[tempTrainingSize];
        testingSet = new int[tempSize - tempTrainingSize];
        //为训练集与测试集赋值
        for (int i = 0; i < tempTrainingSize; i++) {
            trainingSet[i] = tempIndices[i];
        }
        for (int i = 0; i < tempSize - tempTrainingSize; i++) {
            testingSet[i] = tempIndices[tempTrainingSize + i];
        }
    }

距离计算

这里规定了两种距离,并实现了两种距离的计算。

    /**
     * 计算两条记录之间对应的距离
     * @param paraI 一个记录下标
     * @param paraJ 另一条记录下标
     * @return 距离结果
     */
    public double distance(int paraI, int paraJ) {
        double resultDistance = 0;
        double tempDifference;
        //两种距离策略分开计算
        switch (distanceMeasure) {
            case MANHATTAN:
                //对除去标签以外的所有属性进行处理
                for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                    //归一化处理?
                    tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
                    if (tempDifference < 0) {
                        resultDistance -= tempDifference;
                    } else {
                        resultDistance += tempDifference;
                    }
                }
                break;
            case EUCLIDEAN:
                for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                    tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
                    resultDistance += tempDifference * tempDifference;
                }
                break;
            default:
                System.out.println("Unsupported distance measure: " + distanceMeasure);
        }
        return resultDistance;
    }

得到k个最近记录

该部分是输入当前记录的下标,计算距离最近的k个记录,并返回。

    /**
     * 得到距离最近的k个记录,存放在数组中
     * @param paraCurrent 当前记录的下标
     * @return 返回k个最近记录
     */
    public int[] computeNearests(int paraCurrent) {
        //声明空间为k的数组
        int[] resultNearests = new int[numNeighbors];
        //声明长度同训练集的数组,作用是判断对应训练集记录是否已经被加入resultNearests数组
        boolean[] tempSelected = new boolean[trainingSet.length];
        double tempMinimalDistance;
        int tempMinimalIndex = 0;
        //
        double[] tempDistances = new double[trainingSet.length];
        for (int i = 0; i < trainingSet.length; i ++) {
            tempDistances[i] = distance(paraCurrent, trainingSet[i]);
        }
        for (int i = 0; i < numNeighbors; i++) {
            tempMinimalDistance = Double.MAX_VALUE;

            for (int j = 0; j < trainingSet.length; j++) {
                if (tempSelected[j]) {
                    continue;
                }
                if (tempDistances[j] < tempMinimalDistance) {
                    tempMinimalDistance = tempDistances[j];
                    tempMinimalIndex = j;
                }
            }
            resultNearests[i] = trainingSet[tempMinimalIndex];
            tempSelected[tempMinimalIndex] = true;
        }
        System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
        return resultNearests;
    }

对单条记录进行预测

该部分是输入一条记录,利用得到的k个最近记录,实现预测功能。

    /**
     * 对单条记录进行预测
     * @param paraNeighbors k个最近记录
     * @return 返回预测得到的标签
     */
    public int simpleVoting(int[] paraNeighbors) {
        //以数据集的标签个数建立数组,目的是统计对于当前记录而言,最接近哪个标签
        int[] tempVotes = new int[dataset.numClasses()];
        //循环统计最近的k个记录,将k个记录的所属标签计入数组
        for (int i = 0; i < paraNeighbors.length; i++) {
            tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
        }
        //选择最多的标签作为自己的标签
        int tempMaximalVotingIndex = 0;
        int tempMaximalVoting = 0;
        for (int i = 0; i < dataset.numClasses(); i++) {
            if (tempVotes[i] > tempMaximalVoting) {
                tempMaximalVoting = tempVotes[i];
                tempMaximalVotingIndex = i;
            }
        }
        //返回所属标签的下标
        return tempMaximalVotingIndex;
    }

    /**
     * 预测单条记录的标签
     * @param paraIndex 单条记录的下标
     * @return 返回标签值的序号
     */
    public int predict(int paraIndex) {
        //得到距离最近的k个记录
        int[] tempNeighbors = computeNearests(paraIndex);
        //投票得到k个记录中最多的标签值
        int resultPrediction = simpleVoting(tempNeighbors);
        //返回预测标签的序号
        return resultPrediction;
    }

对测试集进行预测存入验证集

利用单条记录预测功能实现全部测试集的预测功能。

    /**
     * 将测试集的预测结果放在验证集中
     */
    public void predict() {
        //为验证集赋予空间
        predictions = new int[testingSet.length];
        //将测试集的预测结果放在验证集中
        for (int i = 0; i < predictions.length; i++) {
            predictions[i] = predict(testingSet[i]);
        }
    }

计算准确率

根据验证集的结果与真实结果进行比对,计算模型预测准确率。

    /**
     * 根据验证集的结果与真实结果进行比对,计算模型预测准确率
     * @return 预测准确率
     */
    public double getAccuracy() {
        double tempCorrect = 0;
        for (int i = 0; i < predictions.length; i++) {
            if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
                tempCorrect++;
            }
        }
        return tempCorrect / testingSet.length;
    }

详细注释

package knn_nb;

import weka.core.Instances;

import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;

public class Knn {
    //两种距离
    public static final int MANHATTAN = 0;
    public static final int EUCLIDEAN = 1;
    //距离策略
    public int distanceMeasure = EUCLIDEAN;
    //设置随机种子
    public static final Random random = new Random();
    //设置k值
    int numNeighbors = 7;
    //数据集
    Instances dataset;
    //训练集、测试集、验证集数组
    int[] trainingSet;
    int[] testingSet;
    int[] predictions;
    /**
     * 构造方法,用Instances类读取数据集
     * @param paraFilename 数据集地址
     */
    public Knn(String paraFilename) {
        try {
            FileReader fileReader = new FileReader(paraFilename);
            //这里用Instances读取数据集
            dataset = new Instances(fileReader);
            //将当前Istances类的标签下标设置为(数据集的属性数-1)?
            dataset.setClassIndex(dataset.numAttributes() - 1);
            fileReader.close();
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    /**
     * 获得一个随机序列
     * @param paraLength 序列长度
     * @return 打乱后的序列
     */
    public static int[] getRandomIndices(int paraLength) {
        int[] resultIndices = new int[paraLength];
        //按序号赋值
        for(int i = 0; i < paraLength; i++) {
            resultIndices[i] = i;
        }
        //随机交换打乱
        int tempFirst, tempSecond, tempValue;
        for(int i = 0; i < paraLength; i++) {
            tempFirst = random.nextInt(paraLength);
            tempSecond = random.nextInt(paraLength);
            tempValue = resultIndices[tempFirst];
            resultIndices[tempFirst] = resultIndices[tempSecond];
            resultIndices[tempFirst] = tempValue;
        }
        return resultIndices;
    }

    /**
     * 划分训练集与验证集,输入训练集的占比,将数据集的下标放入trainingSet和testingSet中
     * @param paraTrainingFraction 训练集比例
     *
     */
    public void splitTrainingTesting(double paraTrainingFraction) {
        //获得数据集记录数
        int tempSize = dataset.numInstances();
        //得到长度为tempSize的随机序列
        int[] tempIndices = getRandomIndices(tempSize);
        //获得训练集记录数
        int tempTrainingSize = (int)(tempSize * paraTrainingFraction);
        //声明训练集与测试集
        trainingSet = new int[tempTrainingSize];
        testingSet = new int[tempSize - tempTrainingSize];
        //为训练集与测试集赋值
        for (int i = 0; i < tempTrainingSize; i++) {
            trainingSet[i] = tempIndices[i];
        }
        for (int i = 0; i < tempSize - tempTrainingSize; i++) {
            testingSet[i] = tempIndices[tempTrainingSize + i];
        }
    }

    /**
     * 将测试集的预测结果放在验证集中
     */
    public void predict() {
        //为验证集赋予空间
        predictions = new int[testingSet.length];
        //将测试集的预测结果放在验证集中
        for (int i = 0; i < predictions.length; i++) {
            predictions[i] = predict(testingSet[i]);
        }
    }

    /**
     * 预测单条记录的标签
     * @param paraIndex 单条记录的下标
     * @return 返回标签值的序号
     */
    public int predict(int paraIndex) {
        //得到距离最近的k个记录
        int[] tempNeighbors = computeNearests(paraIndex);
        //投票得到k个记录中最多的标签值
        int resultPrediction = simpleVoting(tempNeighbors);
        //返回预测标签的序号
        return resultPrediction;
    }

    /**
     * 计算两条记录之间对应的距离
     * @param paraI 一个记录下标
     * @param paraJ 另一条记录下标
     * @return 距离结果
     */
    public double distance(int paraI, int paraJ) {
        double resultDistance = 0;
        double tempDifference;
        //两种距离策略分开计算
        switch (distanceMeasure) {
            case MANHATTAN:
                //对除去标签以外的所有属性进行处理
                for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                    //归一化处理?
                    tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
                    if (tempDifference < 0) {
                        resultDistance -= tempDifference;
                    } else {
                        resultDistance += tempDifference;
                    }
                }
                break;
            case EUCLIDEAN:
                for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                    tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
                    resultDistance += tempDifference * tempDifference;
                }
                break;
            default:
                System.out.println("Unsupported distance measure: " + distanceMeasure);
        }
        return resultDistance;
    }

    /**
     * 根据验证集的结果与真实结果进行比对,计算模型预测准确率
     * @return 预测准确率
     */
    public double getAccuracy() {
        double tempCorrect = 0;
        for (int i = 0; i < predictions.length; i++) {
            if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
                tempCorrect++;
            }
        }
        return tempCorrect / testingSet.length;
    }

    /**
     * 得到距离最近的k个记录,存放在数组中
     * @param paraCurrent 当前记录的下标
     * @return 返回k个最近记录
     */
    public int[] computeNearests(int paraCurrent) {
        //声明空间为k的数组
        int[] resultNearests = new int[numNeighbors];
        //声明长度同训练集的数组,作用是判断对应训练集记录是否已经被加入resultNearests数组
        boolean[] tempSelected = new boolean[trainingSet.length];
        double tempMinimalDistance;
        int tempMinimalIndex = 0;
        //
        double[] tempDistances = new double[trainingSet.length];
        for (int i = 0; i < trainingSet.length; i ++) {
            tempDistances[i] = distance(paraCurrent, trainingSet[i]);
        }
        for (int i = 0; i < numNeighbors; i++) {
            tempMinimalDistance = Double.MAX_VALUE;

            for (int j = 0; j < trainingSet.length; j++) {
                if (tempSelected[j]) {
                    continue;
                }
                if (tempDistances[j] < tempMinimalDistance) {
                    tempMinimalDistance = tempDistances[j];
                    tempMinimalIndex = j;
                }
            }
            resultNearests[i] = trainingSet[tempMinimalIndex];
            tempSelected[tempMinimalIndex] = true;
        }
        System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
        return resultNearests;
    }

    /**
     * 对单条记录进行预测
     * @param paraNeighbors k个最近记录
     * @return 返回预测得到的标签
     */
    public int simpleVoting(int[] paraNeighbors) {
        //以数据集的标签个数建立数组,目的是统计对于当前记录而言,最接近哪个标签
        int[] tempVotes = new int[dataset.numClasses()];
        //循环统计最近的k个记录,将k个记录的所属标签计入数组
        for (int i = 0; i < paraNeighbors.length; i++) {
            tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
        }
        //选择最多的标签作为自己的标签
        int tempMaximalVotingIndex = 0;
        int tempMaximalVoting = 0;
        for (int i = 0; i < dataset.numClasses(); i++) {
            if (tempVotes[i] > tempMaximalVoting) {
                tempMaximalVoting = tempVotes[i];
                tempMaximalVotingIndex = i;
            }
        }
        //返回所属标签的下标
        return tempMaximalVotingIndex;
    }

    /**
     * 主函数,启动预测并展示准确率
     * @param args
     */
    public static void main(String args[]) {
        Knn knn = new Knn("C:\\Users\\hp\\Desktop\\deepLearning\\src\\main\\java\\knn_nb\\iris.arff");
        knn.splitTrainingTesting(0.8);
        knn.predict();
        System.out.println("The accuracy of the classifier is: " + knn.getAccuracy());
    }

}

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
KNN是一种机器学习算法,主要用于分类和回归。它的工作原理是将每个数据点分配到它最接近的k个邻居之一,然后利用这些邻居的标签来预测新数据点的标签。 以下是使用Java实现KNN算法的示例代码: ```java import java.util.*; public class KNN { // 使用欧几里得距离计算两个数据点之间的距离 public static double euclideanDistance(double[] x, double[] y) { double distance = 0; for (int i = 0; i < x.length; ++i) { distance += Math.pow(x[i] - y[i], 2); } return Math.sqrt(distance); } // 在给定的训练集中查找k个最近邻居 public static int[] nearestNeighbors(double[] x, double[][] data, int k) { double[] distances = new double[data.length]; // 计算x和数据集中每个点的距离 for (int i = 0; i < data.length; ++i) { distances[i] = euclideanDistance(x, data[i]); } // 找到k个最近邻居的索引 int[] neighbors = new int[k]; for (int i = 0; i < k; ++i) { int index = 0; double min = distances[0]; for (int j = 1; j < distances.length; ++j) { if (distances[j] < min) { index = j; min = distances[j]; } } neighbors[i] = index; distances[index] = Double.MAX_VALUE; } return neighbors; } // 对x进行分类 public static String classify(double[] x, double[][] data, String[] labels, int k) { // 找到k个最近邻居的索引 int[] neighbors = nearestNeighbors(x, data, k); // 统计每个类的数量 Map<String, Integer> counts = new HashMap<>(); for (int i = 0; i < neighbors.length; ++i) { String label = labels[neighbors[i]]; counts.put(label, counts.getOrDefault(label, 0) + 1); } // 找到数量最多的类 String result = null; int maxCount = -1; for (String label : counts.keySet()) { int count = counts.get(label); if (count > maxCount) { result = label; maxCount = count; } } return result; } public static void main(String[] args) { double[][] data = new double[][]{{1, 1}, {2, 2}, {3, 3}, {4, 4}, {5, 5}}; String[] labels = new String[]{"A", "A", "B", "B", "B"}; double[] x = new double[]{2.5, 2.5}; int k = 3; String result = classify(x, data, labels, k); System.out.println("分类结果:" + result); } } ``` 在这个示例中,我们使用欧几里得距离作为两个数据点之间的距离度量,然后使用nearestNeighbors方法找到最近的k个邻居,最后使用classify方法对新数据点进行分类。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值