Day 53 kNN分类器

53.1 增加 weightedVoting() 方法，距离越短话语权越大，支持两种以上的加权方式。

    /**
     *********************
     * Weighted voting among the neighbors: the shorter the distance, the
     * stronger the say. Each neighbor contributes a vote of weight
     * 1 / (distance + epsilon) to its own class (epsilon avoids division by
     * zero when a neighbor coincides with the current instance); other
     * decreasing weight functions, e.g. Gaussian exp(-d^2), can be plugged
     * in the same way.
     *
     * @param paraCurrent   current instance. We are comparing it with all others.
     * @param paraNeighbors The indices of the neighbors.
     * @return The predicted label.
     *********************
     */
    public int weightedVoting(int paraCurrent, int[] paraNeighbors) {
        double[] tempVotes = new double[dataset.numClasses()];
        double tempDistance;

        // Accumulate one weighted vote per neighbor.
        for (int i = 0; i < paraNeighbors.length; i++) {
            tempDistance = distance(paraCurrent, paraNeighbors[i]);
            tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()] += 1.0 / (tempDistance + 1e-6);
        } // Of for i

        // Return the class with the maximal accumulated weight.
        int tempMaximalIndex = 0;
        for (int i = 1; i < tempVotes.length; i++) {
            if (tempVotes[i] > tempVotes[tempMaximalIndex]) {
                tempMaximalIndex = i;
            } // Of if
        } // Of for i

        return tempMaximalIndex;
    }// Of weightedVoting

53.2实现 leave-one-out 测试。

package dl;

import java.io.FileReader;
import java.util.Arrays;

import weka.core.*;

/**
 * kNN classification for leave-one-out measure to test.
 */

public class KnnClassificationLeaveOneOut {

    /**
     * Manhattan (L1) distance measure.
     */
    public static final int MANHATTAN = 0;

    /**
     * Euclidean distance measure. Note: distance() actually returns the
     * SQUARED Euclidean distance; the square root is skipped because it is
     * monotone and does not change the neighbor ordering.
     */
    public static final int EUCLIDEAN = 1;

    /**
     * Uniform weighting: every neighbor has the same voting power.
     */
    public static final int UNIFORM_WEIGHTING = 0;

    /**
     * Inverse-distance weighting: weight = 1 / (distance + epsilon).
     */
    public static final int INVERSE_WEIGHTING = 1;

    /**
     * Gaussian weighting: weight = exp(-distance^2).
     */
    public static final int GAUSSIAN_WEIGHTING = 2;

    /**
     * A small constant to avoid division by zero in inverse-distance
     * weighting when a neighbor coincides with the current instance.
     */
    private static final double EPSILON = 1e-6;

    /**
     * The distance measure in use.
     */
    public int distanceMeasure = EUCLIDEAN;

    /**
     * The weighting scheme used by weightedVoting(). Inverse-distance by
     * default; see setWeightingScheme().
     */
    int weightingScheme = INVERSE_WEIGHTING;

    /**
     * The number of neighbors (the k of kNN).
     */
    int numNeighbors = 7;

    /**
     * The whole dataset.
     */
    Instances dataset;

    /**
     * The training set, represented by the indices of the data.
     */
    int[] trainingSet;

    /**
     * The predictions, one per instance of the dataset; filled by predict().
     */
    int[] predictions;

    /**
     *********************
     * The first constructor. Reads the arff file and sets the last
     * attribute as the decision class. Exits the program on failure.
     *
     * @param paraFilename The arff filename.
     *********************
     */
    public KnnClassificationLeaveOneOut(String paraFilename) {
        try {
            FileReader fileReader = new FileReader(paraFilename);
            dataset = new Instances(fileReader);
            // The last attribute is the decision class.
            dataset.setClassIndex(dataset.numAttributes() - 1);
            fileReader.close();
        } catch (Exception ee) {
            System.out.println("Error occurred while trying to read \'" + paraFilename
                    + "\' in KnnClassification constructor.\r\n" + ee);
            System.exit(0);
        } // Of try
    }// Of the first constructor

    /**
     *********************
     * Obtain trainingSet from dataset. For leave-one-out every instance is
     * a training candidate; the tested instance itself is excluded inside
     * computeNearests().
     *********************
     */
    public void setTrainingSet() {
        int tempSize = dataset.numInstances();

        trainingSet = new int[tempSize];
        for (int i = 0; i < tempSize; i++) {
            trainingSet[i] = i;
        } // Of for i
    }// Of setTrainingSet

    /**
     *********************
     * Predict for the whole dataset (leave-one-out). The results are
     * stored in predictions.
     * #see predictions.
     *********************
     */
    public void predict() {
        predictions = new int[dataset.numInstances()];
        for (int i = 0; i < dataset.numInstances(); i++) {
            System.out.print("Try to predict " + i + " row of data: ");
            predictions[i] = predict(i);
            System.out.println("Prediction class is " + predictions[i]);
        } // Of for i
    }// Of predict

    /**
     *********************
     * Predict for the given instance, using all other instances as the
     * training set. Uses simple (unweighted) voting; substitute
     * weightedVoting(paraIndex, tempNeighbors) for distance-weighted voting.
     *
     * @param paraIndex The index of the instance to predict.
     * @return The prediction.
     *********************
     */
    public int predict(int paraIndex) {
        int[] tempNeighbors = computeNearests(paraIndex);
        int resultPrediction = simpleVoting(tempNeighbors);

        return resultPrediction;
    }// Of predict

    /**
     *********************
     * The distance between two instances over all non-class attributes.
     * For EUCLIDEAN the SQUARED distance is returned, which preserves the
     * neighbor ordering. An unsupported measure is reported and yields 0.
     *
     * @param paraI The index of the first instance.
     * @param paraJ The index of the second instance.
     * @return The distance.
     *********************
     */
    public double distance(int paraI, int paraJ) {
        double resultDistance = 0;
        double tempDifference;
        switch (distanceMeasure) {
            case MANHATTAN:
                for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                    tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
                    // Accumulate the absolute difference.
                    if (tempDifference < 0) {
                        resultDistance -= tempDifference;
                    } else {
                        resultDistance += tempDifference;
                    } // Of if
                } // Of for i
                break;

            case EUCLIDEAN:
                for (int i = 0; i < dataset.numAttributes() - 1; i++) {
                    tempDifference = dataset.instance(paraI).value(i) - dataset.instance(paraJ).value(i);
                    resultDistance += tempDifference * tempDifference;
                } // Of for i
                break;
            default:
                System.out.println("Unsupported distance measure: " + distanceMeasure);
        }// Of switch

        return resultDistance;
    }// Of distance

    /**
     *********************
     * Get the accuracy of the classifier. Requires predict() to have been
     * called first.
     *
     * @return The fraction of correctly predicted instances.
     *********************
     */
    public double getAccuracy() {
        // A double divided by an int gives another double.
        double tempCorrect = 0;
        for (int i = 0; i < predictions.length; i++) {
            if (predictions[i] == dataset.instance(i).classValue()) {
                tempCorrect++;
            } // Of if
        } // Of for i

        return tempCorrect / predictions.length;
    }// Of getAccuracy

    /**
     ************************************
     * Compute the nearest k neighbors, excluding paraCurrent itself
     * (leave-one-out). Selects one neighbor per scan; a single-scan
     * selection would also work. NOTE: marking tempSelected[paraCurrent]
     * relies on trainingSet[i] == i, which setTrainingSet() guarantees.
     *
     * @param paraCurrent current instance. We are comparing it with all others.
     * @return the indices of the nearest instances.
     ************************************
     */
    public int[] computeNearests(int paraCurrent) {
        // Clamp k: at most all instances except paraCurrent itself are
        // available, otherwise a stale index would be re-selected.
        int tempNumNeighbors = Math.min(numNeighbors, trainingSet.length - 1);
        int[] resultNearests = new int[tempNumNeighbors];
        boolean[] tempSelected = new boolean[trainingSet.length];
        double tempMinimalDistance;
        int tempMinimalIndex = 0;
        tempSelected[paraCurrent] = true;

        // Compute all distances once to avoid redundant computation.
        double[] tempDistances = new double[trainingSet.length];
        for (int i = 0; i < trainingSet.length; i++) {
            tempDistances[i] = distance(paraCurrent, trainingSet[i]);
        } // Of for i

        // Select the nearest k indices, one per scan.
        for (int i = 0; i < tempNumNeighbors; i++) {
            tempMinimalDistance = Double.MAX_VALUE;

            for (int j = 0; j < trainingSet.length; j++) {
                if (tempSelected[j]) {
                    continue;
                } // Of if

                if (tempDistances[j] < tempMinimalDistance) {
                    tempMinimalDistance = tempDistances[j];
                    tempMinimalIndex = j;
                } // Of if
            } // Of for j

            resultNearests[i] = trainingSet[tempMinimalIndex];
            tempSelected[tempMinimalIndex] = true;
        } // Of for i

        return resultNearests;
    }// Of computeNearests

    /**
     ************************************
     * Simple voting: every neighbor has one vote; the majority class wins
     * (ties broken by the smaller class index).
     *
     * @param paraNeighbors The indices of the neighbors.
     * @return The predicted label.
     ************************************
     */
    public int simpleVoting(int[] paraNeighbors) {
        int[] tempVotes = new int[dataset.numClasses()];
        for (int i = 0; i < paraNeighbors.length; i++) {
            tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()]++;
        } // Of for i

        int tempMaximalVotingIndex = 0;
        int tempMaximalVoting = 0;
        for (int i = 0; i < dataset.numClasses(); i++) {
            if (tempVotes[i] > tempMaximalVoting) {
                tempMaximalVoting = tempVotes[i];
                tempMaximalVotingIndex = i;
            } // Of if
        } // Of for i

        return tempMaximalVotingIndex;
    }// Of simpleVoting

    /**
     ************************************
     * Set numNeighbors. Kept (with its original spelling) for backward
     * compatibility; see also setNumNeighbors().
     *
     * @param paraNeighbors The number of neighbors.
     ************************************
     */
    public void setNumNeighors(int paraNeighbors) {
        numNeighbors = paraNeighbors;
        return;
    }// Of setNumNeighors

    /**
     ************************************
     * Set numNeighbors (correctly spelled alias of setNumNeighors).
     *
     * @param paraNeighbors The number of neighbors.
     ************************************
     */
    public void setNumNeighbors(int paraNeighbors) {
        setNumNeighors(paraNeighbors);
    }// Of setNumNeighbors

    /**
     *********************
     * Set the distance measure.
     *
     * @param paraDistanceMeasure The distance measure: 0 for MANHATTAN,
     *                            1 for EUCLIDEAN.
     *********************
     */
    public void setDistanceMeasure(int paraDistanceMeasure) {
        distanceMeasure = paraDistanceMeasure;
        return;
    }// Of setDistanceMeasure

    /**
     *********************
     * Set the weighting scheme for weightedVoting().
     *
     * @param paraWeightingScheme UNIFORM_WEIGHTING, INVERSE_WEIGHTING or
     *                            GAUSSIAN_WEIGHTING.
     *********************
     */
    public void setWeightingScheme(int paraWeightingScheme) {
        weightingScheme = paraWeightingScheme;
    }// Of setWeightingScheme

    /**
     *********************
     * Weighted voting among the neighbors: the shorter the distance, the
     * stronger the say. Three weighting schemes are supported (see
     * setWeightingScheme()): uniform (equivalent to simple voting),
     * inverse-distance 1 / (d + epsilon), and Gaussian exp(-d^2). An
     * unknown scheme falls back to uniform weighting.
     *
     * @param paraCurrent   current instance. We are comparing it with all others.
     * @param paraNeighbors The indices of the neighbors.
     * @return The predicted label.
     *********************
     */
    public int weightedVoting(int paraCurrent, int[] paraNeighbors) {
        double[] tempVotes = new double[dataset.numClasses()];
        double tempDistance, tempWeight;

        // Accumulate one weighted vote per neighbor.
        for (int i = 0; i < paraNeighbors.length; i++) {
            tempDistance = distance(paraCurrent, paraNeighbors[i]);
            switch (weightingScheme) {
                case INVERSE_WEIGHTING:
                    tempWeight = 1.0 / (tempDistance + EPSILON);
                    break;
                case GAUSSIAN_WEIGHTING:
                    tempWeight = Math.exp(-tempDistance * tempDistance);
                    break;
                default:
                    // UNIFORM_WEIGHTING and unknown schemes.
                    tempWeight = 1.0;
            } // Of switch
            tempVotes[(int) dataset.instance(paraNeighbors[i]).classValue()] += tempWeight;
        } // Of for i

        // Return the class with the maximal accumulated weight.
        int tempMaximalIndex = 0;
        for (int i = 1; i < tempVotes.length; i++) {
            if (tempVotes[i] > tempVotes[tempMaximalIndex]) {
                tempMaximalIndex = i;
            } // Of if
        } // Of for i

        return tempMaximalIndex;
    }// Of weightedVoting

    /**
     *********************
     * The entrance of the program: leave-one-out evaluation on iris.
     *
     * @param args Not used now.
     *********************
     */
    public static void main(String args[]) {
        KnnClassificationLeaveOneOut tempClassifier = new KnnClassificationLeaveOneOut(
                "C:\\Users\\86183\\IdeaProjects\\deepLearning\\src\\main\\java\\resources\\iris.arff");
        tempClassifier.setTrainingSet();
        tempClassifier.predict();
        System.out.println("The accuracy of the classifier is: " + tempClassifier.getAccuracy());
    }// Of main

}// Of class KnnClassificationLeaveOneOut

结果:

 

说明:本文代码参考张星移学长代码

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值