日撸java 三百行 趁热打铁(01)KNN分类器

前两天刚听完闵帆老师的讲座,赶紧趁着热乎开撸。

KNN是机器学习里面的入门基础算法之一,但是它的普适性很强,对于新的问题,把KNN拿出来缝缝补补改改它又能战斗了,所以可以把它当做算法检测标杆

KNN思想(人类的比较思维):要判断一个未知的事物,可以找一个我们知道并且与之最相似的事物,我们就认为它俩是同一种事物。那么具体落到计算机上要怎么实现呢?

其最主要的就是要模拟找相似的过程,对于输入的一个向量,可以考虑衡量它与已知数据的距离,如果距离值越小,就认为两组数组很接近。

距离的计算方式常用有两种,曼哈顿距离与欧几里得距离(这两名字听起来特别离谱,实际上就是初中学的),接着我们找到与之距离最近的K个邻居,K个邻居投票决定待测样本的类别。K的取值对结果的影响还是比较大的,取小了过拟合,取大了欠拟合(这个地方有点迷糊),K尽量取奇数可以避免投票出现平票的尴尬情况。

package com.trian;
import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;

import myself.test;
import weka.core.*;

/**
 * kNN classification.
 *
 */
public class KnnClassification {
    /**
     * Manhattan distance.
     */
    public static final int  MANHATTAN = 0;

    /**
     * Euclidean distance.
     */
    public static final int EUCLIDEAN = 1;


    /**
     * The default distance measure.
     */
    public int distanceMeasure = EUCLIDEAN;

    /**
     * A random instance;
     */
    public static final Random random = new Random();

    /**
     * The Number of reference neighbors
     */
    int numNeighbors=7;

    /**
     * The dataset
     */
    Instances dataset;

    /**
     * The trainingSet
     * The TestingSet
     * The predictions
     */
    int []trainingSet;
    int []testingSet;
    int []predictions;

    /*
     *********************
     * @param paraFilename
     *              The arff filename.
     * @return null
     *
     *********************
     */
    public KnnClassification(String paraFilename){
        try{
            FileReader fileReader = new FileReader(paraFilename);
            dataset=new Instances(fileReader);
            dataset.setClassIndex(dataset.numAttributes()-1);
            fileReader.close();
        }catch(Exception ee){
            System.out.println("404 NOT Found");
            System.exit(0);
        }//of try
    }//of KnnClassification

    /*
     *********************
     * @param paraLength
     *            The The length of the sequence.
     * @return int[]
     *
     *********************
     */
    public static int[] getRandomIndices(int paraLength){
        int[] resultIndices = new int[paraLength];
        for(int i=0;i<paraLength;++i){
            resultIndices[i]=i;
        }// Of for i
        int tempFirst, tempSecond, tempValue;
        for(int i=0;i<paraLength;++i){
            tempFirst=random.nextInt(paraLength);
            tempSecond = random.nextInt(paraLength);
            tempValue = resultIndices[tempFirst];
            resultIndices[tempFirst] = resultIndices[tempSecond];
            resultIndices[tempSecond] = tempValue;
        }//of for i
        return resultIndices;
    }//of getRandomIndices

    /*
     *********************
     * @param paraTrainingFraction
     *              The fraction of the training set.
     * @return void
     *
     *********************
     */
    public void splitTrainingTesting(double paraTrainingFraction){
        int [] tempIndices=getRandomIndices(dataset.numInstances());
        int tempTrainingSize=(int)paraTrainingFraction*dataset.numInstances();
        trainingSet = new int[tempTrainingSize];
        testingSet = new int[dataset.numInstances() - tempTrainingSize];

        for (int i = 0; i < tempTrainingSize; i++) {
            trainingSet[i] = tempIndices[i];
        } // Of for i

        for (int i = 0; i < dataset.numInstances() - tempTrainingSize; i++) {
            testingSet[i] = tempIndices[tempTrainingSize + i];
        } // Of for i
    }//of splitTrainingTesting

    /*
     *********************
     * @return void
     *********************
     */
    public void predict(){
        predictions=new int[testingSet.length];
        for(int i=0;i<predictions.length;++i){
            predictions[i]=predict(testingSet[i]);
        }//of for i
    }//of predict

    /*
     *********************
     * @param paraIndex
     *          Predict for given instance.
     * @return int
     *
     *********************
     */
    public int predict(int paraIndex){
        int[] tempNeighbors = computeNearests(paraIndex);
        int resultPrediction = simpleVoting(tempNeighbors);

        return resultPrediction;
    }// of predict

    public int[] computeNearests(int paraCurrent) {
        int[] resultNearests = new int[numNeighbors];
        boolean[] tempSelected = new boolean[trainingSet.length];
        double tempMinimalDistance;
        int tempMinimalIndex = 0;

        double[] tempDistances = new double[trainingSet.length];
        for (int i = 0; i < trainingSet.length; ++i) {
            tempDistances[i] = distance(paraCurrent, trainingSet[i]);
        }//of for i
        for (int i = 0; i < numNeighbors; i++) {
            tempMinimalDistance = Double.MAX_VALUE;

            for (int j = 0; j < trainingSet.length; j++) {
                if (tempSelected[j]) {
                    continue;
                } // Of if

                if (tempDistances[j] < tempMinimalDistance) {
                    tempMinimalDistance = tempDistances[j];
                    tempMinimalIndex = j;
                } // Of if
            } // Of for j

            resultNearests[i] = trainingSet[tempMinimalIndex];
            tempSelected[tempMinimalIndex] = true;
        }//of for i
        System.out.println("The nearest of " + paraCurrent + " are: " + Arrays.toString(resultNearests));
        return resultNearests;
    }// of computeNearests

    public double distance(int para1,int para2){
        double resultDistance = 0;
        double tempDifference;
        switch (distanceMeasure){
            case MANHATTAN:
                for(int i=0;i<dataset.numAttributes()-1;++i){
                    tempDifference=dataset.instance(para1).value(i)-dataset.instance(para2).value(i);
                    Math.abs(tempDifference);
                    resultDistance+=tempDifference;
                }//of for i
            case EUCLIDEAN:
                for (int i = 0; i < dataset.numAttributes() - 1; i++){
                    tempDifference=dataset.instance(para1).value(i)-dataset.instance(para2).value(i);
                    resultDistance+=tempDifference*tempDifference;
                }//of for i
        }//of switch
        return resultDistance;
    }// of distance

    public int simpleVoting(int []paraNeighbors){
        int[] tempVotes = new int[dataset.numClasses()];
        for(int i=0;i<numNeighbors;++i){
            tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
        }//of for i
        int tempMaximalVotingIndex = 0;
        int tempMaximalVoting = 0;
        for (int i = 0; i < dataset.numClasses(); i++) {
            if (tempVotes[i] > tempMaximalVoting) {
                tempMaximalVoting = tempVotes[i];
                tempMaximalVotingIndex = i;
            } // Of if
        } // Of for i

        return tempMaximalVotingIndex;
    }//of simpleVoting


    public  double getAccuracy() {
        double tempCorrect = 0;
        for(int i=0;i<predictions.length;++i){
            if(predictions[i]==(int)dataset.instance(testingSet[i]).classValue()){
                tempCorrect++;
            } // Of if
        } // Of for i
        return tempCorrect/testingSet.length;
    }// of getAccuracy

    public static void main(String args[]) {
        test tempClassifier = new test("C:/Users/胡来的魔术师/Desktop/sampledata-main/iris.arff");
        tempClassifier.splitTrainingTesting(0.95);
        tempClassifier.predict();
        System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
    }// Of main

}// Of class KnnClassification

运行结果:

 代码分析:

dataset用于管理arff数据集,数据预处理首先随机打乱数据,然后根据自己的喜好分配训练集与测试集,computeNearests(paraIndex)方法会选出与当前参数最近的k个邻居

simpleVoting(tempNeighbors)方法会依次遍历K个邻居并记录他们的类别数量,最后返回数量最多的那个类别。

distance(int para1,int para2)计算两组数据间的距离。

总结:knn算法很容易理解,感觉也比较暴力。理清这个dataset助手后代码读写也会更容易,instance表示行,也就是一行数据,一个数据对象,Attributes表示列,也就是属性。虽然knn是入门的第一个算法,但是也不能小看它,由于它的思想过于简单没有那么多弯弯绕绕,刚接触的时候难免让人觉得离谱,但是仔细一想,我们人类认识新事物好像绝大部分时候也是这么简单。knn没有学习过程,直接在训练集里找“答案”。


                
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值