机器学习——KNN分类器的学习

kNN 的特点:

  1. 简单. 没有学习过程, 也被称为惰性学习 lazy learning. 类似于开卷考试, 在已有数据中去找答案.
  2. 本源. 找相似, 正是人类认识事物的常用方法, 隐藏于人类或者其他动物的基因里面. 当然, 人类也会上当,例如有人把邻居的滴水观音误认为是芋头, 偷食后中毒.
  3. 效果好. 永远不要小视 kNN, 对于很多数据, 你很难设计算法超越它.
  4. 适应性强. 可用于分类, 回归. 可用于各种数据.
  5. 可扩展性强. 设计不同的度量, 可获得意想不到的效果.
  6. 一般需要对数据归一化.
  7. 复杂度高. 这也是 kNN 最重要的缺点. 对于每一个测试数据, 复杂度为 O ( ( m + k ) n ) , 其中 n 为训练数据个数, m为条件属性个数, k为邻居个数. 代码见 computeNearests().

代码:

package machinelearning.knn;

import weka.core.*;

import java.io.FileNotFoundException;
import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;

public class KnnClassification {

    //曼哈顿距离,|x|+|y|
    public static final int MANHATTAN = 0;

    //欧氏距离
    public static final int EUCLIDEAN = 1;

    //距离衡量方式
    public int distanceMeasure = EUCLIDEAN;

    //一个随机实例
    public static final Random random = new Random();

    //邻居的数量
    int numNeighbors = 7;

    //存储整个数据集
    Instances dataset;

    //训练集,由数据索引表示
    int[] trainingSet;

    //测试集,由数据索引表示
    int[] testingSet;

    //预测结果
    int[] predictions;


    public KnnClassification(String paraFilename) {
        try {
            FileReader fileReader = new FileReader(paraFilename);
            dataset = new Instances(fileReader);
            //最后一个属性是类别
            dataset.setClassIndex(dataset.numAttributes() - 1);
            fileReader.close();
        } catch (Exception e) {
            System.out.println("Error occurred while trying to read \'" + paraFilename
                    + "\' in KnnClassification constructor.\r\n" + e);
            System.exit(0);
        }
    }

    /**
     * 获得一个随机索引用于数据随机化
     *
     * @param paraLength 数据的长度
     * @return 返回一个索引数组
     */
    public static int[] getRandomIndices(int paraLength) {
        int[] resultIndices = new int[paraLength];

        //1. 初始化
        for (int i = 0; i < paraLength; i++) {
            resultIndices[i] = i;
        }

        //2. 随机交换
        int tempFirst, tempSecond, tempValue;
        for (int i = 0; i < paraLength; i++) {
            //产生两个随机索引
            tempFirst = random.nextInt(paraLength);
            tempSecond = random.nextInt(paraLength);

            //交换
            tempValue = resultIndices[tempFirst];
            resultIndices[tempFirst] = resultIndices[tempSecond];
            resultIndices[tempSecond] = tempValue;
        }
        return resultIndices;
    }

    /**
     * 将数据分为训练集与测试集
     *
     * @param paraTrainingFraction 训练集所占比例
     */
    public void splitTrainingTesting(double paraTrainingFraction) {
        int tempSize = dataset.numInstances();//数据集所含数据的数量
        int[] tempIndices = getRandomIndices(tempSize);
        int tempTrainingSize = (int) (tempSize * paraTrainingFraction);

        trainingSet = new int[tempTrainingSize];
        testingSet = new int[tempSize - tempTrainingSize];

        for (int i = 0; i < tempTrainingSize; i++) {
            trainingSet[i] = tempIndices[i];
        }

        for (int i = 0; i < tempSize - tempTrainingSize; i++) {
            testingSet[i] = tempIndices[tempTrainingSize + i];
        }

    }

    /**
     * 预测整个测试集,结果存储在预测集中
     */
    public void predict() {
        predictions = new int[testingSet.length];
        for (int i = 0; i < predictions.length; i++) {
            predictions[i] = predict(testingSet[i]);
        }
    }

    /**
     * 预测给定的实例
     *
     * @param paraIndex
     * @return 预测的结果
     */
    private int predict(int paraIndex) {
        int[] tempNeighbors = computeNearests(paraIndex);
        int resultPrediction = simpleVoting(tempNeighbors);

        return resultPrediction;
    }

    /**
     * 两个实例之间的距离
     *
     * @param paraI 第一个实例的索引
     * @param paraJ 第二个实例的索引
     * @return 距离
     */
    public double distance(int paraI, int paraJ) {
        double resultDistance = 0;
        double tempDifference;
        switch (distanceMeasure) {
            case MANHATTAN:
                for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                    tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
                    if (tempDifference < 0) {
                        resultDistance -= tempDifference;
                    } else {
                        resultDistance += tempDifference;
                    }
                }
                break;
            case EUCLIDEAN:
                for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                    tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
                    resultDistance += tempDifference * tempDifference;
                }
                break;
            default:
                System.out.println("Unsupported distance measure: " + distanceMeasure);
        }
        return resultDistance;
    }

    /**
     * 获取分类器的准确度
     *
     * @return
     */
    public double getAccuracy() {
        double tempCorrect = 0;
        for (int i = 0; i < predictions.length; i++) {
            if (predictions[i] == dataset.instance(testingSet[i]).classValue()) {
                tempCorrect++;
            }
        }
        return tempCorrect / testingSet.length;
    }

    /**
     * 计算最近的n个邻居
     *
     * @param paraCurrent 最近的实例
     * @return 最近实例的索引
     */
    private int[] computeNearests(int paraCurrent) {

        int[] resultNearests = new int[numNeighbors];
        boolean[] tempSelected = new boolean[trainingSet.length];
        double tempMinimalDistance;
        int tempMinimalIndex = 0;

        double[] tempDistances = new double[trainingSet.length];
        for (int i = 0; i < trainingSet.length; i++) {
            tempDistances[i] = distance(paraCurrent, trainingSet[i]);
        }

        //选择最近的k个索引
        for (int i = 0; i < numNeighbors; i++) {
            tempMinimalDistance = Double.MAX_VALUE;

            for (int j = 0; j < trainingSet.length; j++) {
                if (tempSelected[j]) {
                    continue;
                }

                if (tempDistances[j] < tempMinimalDistance) {
                    tempMinimalDistance = tempDistances[j];
                    tempMinimalIndex = j;
                }
            }
            resultNearests[i] = trainingSet[tempMinimalIndex];
            tempSelected[tempMinimalIndex] = true;
        }

        System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
        return resultNearests;
    }




    /**
     * 投票
     *
     * @param paraNeighbors
     * @return
     */
    private int simpleVoting(int[] paraNeighbors) {
        int[] tempVotes = new int[dataset.numClasses()];
        for (int i = 0; i < paraNeighbors.length; i++) {
            tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
        }

        int tempMaximalVotingIndex = 0;
        int tempMaximalVoting = 0;
        for (int i = 0; i < dataset.numClasses(); i++) {
            if (tempVotes[i] > tempMaximalVoting) {
                tempMaximalVoting = tempVotes[i];
                tempMaximalVotingIndex = i;
            }
        }
        return tempMaximalVotingIndex;
    }

    public void setDistanceMeasure(int paraType) {
        if (paraType == 0) {
            distanceMeasure = MANHATTAN;
        } else if (paraType == 1) {
            distanceMeasure = EUCLIDEAN;
        } else {
            System.out.println("Wrong Distance Measure!!!");
        }
    }

    public void setNumNeighbors(int paraNumNeighbors) {
        if (paraNumNeighbors > dataset.numInstances()) {
            System.out.println("out of range");
            return;
        }
        this.numNeighbors = paraNumNeighbors;
    }

    public static void main(String args[]) {
        KnnClassification tempClassifier = new KnnClassification("D:\\研究生学习\\iris.arff");
        tempClassifier.splitTrainingTesting(0.8);
        tempClassifier.predict();
        System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
    }
}

运行结果:

The nearest of 120 are: [143, 140, 124, 144, 112, 139, 102]
The nearest of 3 are: [29, 2, 45, 12, 38, 42, 34]
The nearest of 64 are: [82, 79, 88, 99, 59, 92, 89]
The nearest of 37 are: [34, 9, 1, 12, 29, 45, 2]
The nearest of 148 are: [136, 115, 147, 140, 137, 124, 144]
The nearest of 30 are: [29, 34, 9, 45, 12, 1, 11]
The nearest of 126 are: [123, 127, 138, 146, 83, 63, 72]
The nearest of 117 are: [131, 105, 109, 122, 125, 107, 118]
The nearest of 55 are: [66, 96, 94, 78, 95, 99, 84]
The nearest of 47 are: [2, 42, 6, 29, 38, 12, 45]
The nearest of 90 are: [94, 96, 89, 99, 67, 95, 92]
The nearest of 71 are: [97, 82, 92, 61, 99, 74, 67]
The nearest of 132 are: [128, 104, 103, 111, 112, 140, 147]
The nearest of 49 are: [7, 39, 0, 28, 17, 40, 34]
The nearest of 134 are: [103, 83, 111, 137, 119, 72, 108]
The nearest of 35 are: [1, 2, 40, 28, 34, 9, 7]
The nearest of 10 are: [48, 27, 36, 19, 5, 16, 20]
The nearest of 130 are: [107, 102, 125, 129, 105, 122, 108]
The nearest of 15 are: [33, 14, 5, 16, 32, 48, 19]
The nearest of 8 are: [38, 42, 13, 12, 45, 2, 29]
The nearest of 133 are: [83, 72, 123, 127, 63, 111, 77]
The nearest of 18 are: [5, 48, 20, 16, 31, 36, 33]
The nearest of 69 are: [80, 89, 81, 92, 82, 53, 67]
The nearest of 135 are: [105, 102, 107, 122, 125, 109, 118]
The nearest of 25 are: [34, 9, 1, 12, 45, 29, 7]
The nearest of 46 are: [19, 21, 48, 4, 27, 32, 44]
The nearest of 110 are: [147, 115, 77, 137, 141, 139, 127]
The nearest of 116 are: [137, 103, 147, 111, 128, 112, 104]
The nearest of 145 are: [141, 147, 139, 112, 115, 140, 128]
The nearest of 149 are: [127, 138, 142, 101, 70, 83, 121]
The accuracy of the classifier is: 0.9666666666666667
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值