Java学习第51天:kNN 分类器

这个代码 300 行, 分三天完成. 今天先把代码抄完并运行, 明后天有修改程序的工作. 要求熟练掌握.

1.两种距离度量.
2.数据随机分割方式.
3.间址的灵活使用: trainingSet 和 testingSet 都是整数数组, 表示下标.
4.arff 文件的读取. 需要 weka.jar 包.
5.求邻居.

下载并安装weka.jar包,加载arff文件,代码如下:

package machinelearning.knn;

import java.util.Arrays;
import java.io.FileReader;
import java.util.Random;

import weka.core.*;

/**
 * @description: knn分类器
 * @author: Qing Zhang
 * @time: 2021/7/1
 */
public class KnnClassification {

    //曼哈顿距离,|x|+|y|
    public static final int MANHATTAN = 0;

    //欧氏距离
    public static final int EUCLIDEAN = 1;

    //距离衡量方式
    public int distanceMeasure = EUCLIDEAN;

    //一个随机实例
    public static final Random random = new Random();

    //邻居数量
    int numNeighbors = 7;

    //存储整个数据集
    Instances dataset;

    //训练集。由数据索引表示
    int[] trainingSet;

    //测试集。由数据索引表示
    int[] testingSet;

    //预测结果
    int[] predictions;

    public KnnClassification(String paraFileName) {
        try {
            FileReader fileReader = new FileReader((paraFileName));
            dataset = new Instances(fileReader);
            //最后一个属性是类别
            dataset.setClassIndex((dataset.numAttributes() - 1));
            fileReader.close();
        } catch (Exception e) {
            System.out.println("Error occurred while trying to read \'" + paraFileName
                    + "\' in KnnClassification constructor.\r\n" + e);
            System.exit(0);
        }
    }

    /**
     * @Description: 获得一个随机的索引用于数据随机化
     * @Param: [paraLength: 序列的长度]
     * @return: int[] e.g., {4, 3, 1, 5, 0, 2} with length 6.
     */
    public static int[] getRandomIndices(int paraLength) {
        int[] resultIndices = new int[paraLength];

        //第一步:初始化
        for (int i = 0; i < paraLength; i++) {
            resultIndices[i] = i;
        }

        //第二步:随机交换
        int tempFirst, tempSecond, tempValue;
        for (int i = 0; i < paraLength; i++) {
            //产生两个随机的索引
            tempFirst = random.nextInt(paraLength);
            tempSecond = random.nextInt(paraLength);

            //交换
            tempValue = resultIndices[tempFirst];
            resultIndices[tempFirst] = resultIndices[tempSecond];
            resultIndices[tempSecond] = tempValue;
        }

        return resultIndices;
    }

    /**
     * @Description: 将数据分割为训练集和测试集
     * @Param: [paraTrainingFraction:训练集的占比]
     * @return: void
     */
    public void splitTrainingTesting(double paraTrainingFraction) {
        //numInstances??获取数据集的样本数量。
        int tempSize = dataset.numInstances();
        int[] tempIndices = getRandomIndices(tempSize);
        int tempTrainingSize = (int) (tempSize * paraTrainingFraction);

        trainingSet = new int[tempTrainingSize];
        testingSet = new int[tempSize - tempTrainingSize];

        //将前tempTrainingSize的数据赋给测试集
        for (int i = 0; i < tempTrainingSize; i++) {
            trainingSet[i] = tempIndices[i];
        }

        //将剩余的数据赋给测试集
        for (int i = 0; i < tempSize - tempTrainingSize; i++) {
            testingSet[i] = tempIndices[tempTrainingSize + i];
        }
    }

    /**
     * @Description: 预测整个测试集。结果都存储再预测结果集中
     * @Param: []
     * @return: void
     */
    public void predict() {
        predictions = new int[testingSet.length];
        for (int i = 0; i < predictions.length; i++) {
            predictions[i] = predict(testingSet[i]);
        }
    }

    /**
     * @Description: 预测给定的实例
     * @Param: [paraIndex]
     * @return: int
     */
    public int predict(int paraIndex) {
        int[] tempNeighbors = computeNearests(paraIndex);
        int resultPrediction = simpleVoting(tempNeighbors);

        return resultPrediction;
    }

    /**
     * @Description: 两个实例之间的距离
     * @Param: [paraI, paraJ]
     * @return: double
     */
    public double distance(int paraI, int paraJ) {
        int resultDistance = 0;
        double tempDifference;
        switch (distanceMeasure) {
            case MANHATTAN:
                //numAttributes??
                for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                    tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
                    if (tempDifference < 0) {
                        resultDistance -= tempDifference;
                    } else {
                        resultDistance += tempDifference;
                    }
                }
                break;
            case EUCLIDEAN:
                for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                    tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
                    //这里为什么直接是平方??应该是为了效率,将开根省略了
                    resultDistance += tempDifference * tempDifference;
                }
                break;
            default:
                System.out.println("Unsupported distance measure: " + distanceMeasure);
        }

        return resultDistance;
    }

    /**
     * @Description: 获得分类器的正确率
     * @Param: []
     * @return: double
     */
    public double getAccuracy() {
        //一个double除以一个int会得到另一个double
        double tempCorrect = 0;
        for (int i = 0; i < predictions.length; i++) {
            //classValue??就是数据集中的序号
            if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
                tempCorrect++;
            }
        }

        return tempCorrect / testingSet.length;
    }


    /**
     * @Description: 计算最近的k个邻居。在每一轮扫描中选择一个邻居
     * @Param: [paraIndex]
     * @return: int[]
     */
    public int[] computeNearests(int paraCurrent) {
        int[] resultNearests = new int[numNeighbors];
        boolean[] tempSelected = new boolean[trainingSet.length];
        double tempDistance;
        double tempMinimalDistance;
        int tempMinimalIndex = 0;

        //选择最近的k个索引
        for (int i = 0; i < numNeighbors; i++) {
            tempMinimalDistance = Double.MAX_VALUE;

            for (int j = 0; j < trainingSet.length; j++) {
                if (tempSelected[j]) {
                    continue;
                }

                tempDistance = distance(paraCurrent, trainingSet[j]);
                if (tempDistance < tempMinimalDistance) {
                    tempMinimalDistance = tempDistance;
                    tempMinimalIndex = j;
                }
            }

            resultNearests[i] = trainingSet[tempMinimalIndex];
            tempSelected[tempMinimalIndex] = true;
        }

        System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
        return resultNearests;
    }

    /**
     * @Description: 使用实例投票
     * @Param: [tempNeighbors]
     * @return: int
     */
    public int simpleVoting(int[] paraNeighbors) {
        //numClasses?? 计算每一个类出现的次数
        int[] tempVotes = new int[dataset.numClasses()];
        for (int i = 0; i < paraNeighbors.length; i++) {
            tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
        }

        int tempMaximalVotingIndex = 0;
        int tempMaximalVoting = 0;
        for (int i = 0; i < dataset.numClasses(); i++) {
            if (tempVotes[i] > tempMaximalVoting) {
                tempMaximalVoting = tempVotes[i];
                tempMaximalVotingIndex = i;
            }
        }

        return tempMaximalVotingIndex;
    }

    public static void main(String[] args) {
        KnnClassification tempClassifier = new KnnClassification("F:\\研究生\\研0\\学习\\Java_Study\\data_set\\iris.arff");
        tempClassifier.splitTrainingTesting(0.8);
        tempClassifier.predict();
        System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
    }
}

运行结果如下:

 

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值