Java实现KNN(Weka)

import java.util.Random;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.classifiers.lazy.IBk;


public class KNN {
    /** Number of neighbours used in the majority vote. */
    private static final int KNN_K = 3;
    /** Number of cross-validation folds. */
    private static final int CV_FOLDS = 10;

    /**
     * Loads {@code dataset/colon.arff} and uses its last attribute as the class.
     * It then prints the mean cross-validated accuracy of the hand-written k-NN classifier.
     */
    public static void main(String[] args) throws Exception {
        Instances data = DataSource.read("dataset/colon.arff");
        data.setClassIndex(data.numAttributes() - 1);
        data.randomize(new Random());

        double acc = KFoldFitness(data);
        System.out.println(acc);
    }


    /**
     * Estimates the accuracy of the hand-written k-NN classifier by k-fold cross-validation.
     *
     * <p>The folds come from a shuffled, stratified copy of {@code train}, so the caller's
     * dataset is not reordered. Stratification keeps each fold close to the overall class
     * distribution, following Weka's own cross-validation.
     *
     * @param train dataset whose class index is set; the class attribute must be nominal
     * @return mean of the per-fold accuracies, in [0, 1]
     * @throws IllegalArgumentException if the class attribute is not nominal
     */
    private static double KFoldFitness(Instances train) {
        if (!train.classAttribute().isNominal()) {
            throw new IllegalArgumentException("k-NN voting requires a nominal class attribute");
        }
        // randomize/stratify reorder instances in place, so work on a copy.
        Instances data = new Instances(train);
        data.randomize(new Random());
        data.stratify(CV_FOLDS);

        double acc = 0;
        for (int fold = 0; fold < CV_FOLDS; fold++) {
            Instances trainData = data.trainCV(CV_FOLDS, fold);
            Instances testData = data.testCV(CV_FOLDS, fold);
            // Alternative: acc += KNNAcc(trainData, testData);  (Weka's IBk, for cross-checking)
            double[][] dist = getDistance(trainData, testData);
            int[][] knnIdx = getKNNIndex(dist, KNN_K);
            double[] predictLabel = getPredLabel(knnIdx, trainData);

            int correct = 0;
            for (int j = 0; j < testData.numInstances(); j++) {
                if (predictLabel[j] == testData.instance(j).classValue()) {
                    correct++;
                }
            }
            acc += (double) correct / testData.numInstances();
        }

        return acc / CV_FOLDS;
    }


    /**
     * Computes the squared Euclidean distance from every test instance to every training instance.
     *
     * <p>The class attribute is skipped, wherever it sits. Squared distances are sufficient
     * because only their ordering is used to pick neighbours. NOTE(review): attribute values
     * are used raw, with no normalisation; nominal attributes contribute their label index.
     *
     * @param train training instances (columns of the result)
     * @param test test instances (rows of the result)
     * @return {@code dist[testIdx][trainIdx]}
     */
    private static double[][] getDistance(Instances train, Instances test) {
        int classIdx = test.classIndex();
        int numAttrs = test.numAttributes();
        double[][] dist = new double[test.numInstances()][train.numInstances()];
        for (int i = 0; i < test.numInstances(); i++) {
            for (int j = 0; j < train.numInstances(); j++) {
                double sum = 0;
                for (int k = 0; k < numAttrs; k++) {
                    if (k == classIdx) {
                        continue;
                    }
                    double d = test.instance(i).value(k) - train.instance(j).value(k);
                    sum += d * d;
                }
                dist[i][j] = sum;
            }
        }

        return dist;
    }


    /**
     * For each row of {@code dist}, selects the indices of the {@code knnK} smallest entries.
     *
     * <p>The selection repeatedly takes the minimum among entries not yet chosen. Each row of the
     * result is therefore ordered nearest-first, and ties go to the lower index. {@code dist} is
     * not modified, and no index is reported twice within a row.
     *
     * @param dist distance matrix, one row per test instance
     * @param knnK requested number of neighbours; capped at the row length
     * @return {@code knnIdx[testIdx][rank]}, where row {@code i} has length
     *     {@code min(knnK, dist[i].length)}
     * @throws IllegalArgumentException if {@code knnK < 1}
     */
    private static int[][] getKNNIndex(double[][] dist, int knnK) {
        if (knnK < 1) {
            throw new IllegalArgumentException("knnK must be >= 1, got " + knnK);
        }
        int[][] knnIdx = new int[dist.length][];
        for (int i = 0; i < dist.length; i++) {
            double[] row = dist[i];
            int k = Math.min(knnK, row.length);
            boolean[] taken = new boolean[row.length];
            knnIdx[i] = new int[k];
            for (int rank = 0; rank < k; rank++) {
                int best = -1;
                for (int j = 0; j < row.length; j++) {
                    if (!taken[j] && (best < 0 || row[j] < row[best])) {
                        best = j;
                    }
                }
                taken[best] = true;
                knnIdx[i][rank] = best;
            }
        }

        return knnIdx;
    }


    /**
     * Predicts a class for each test instance by majority vote over its neighbours' labels.
     *
     * <p>Votes are tallied per class index. On a tie, the lowest class index wins. Neighbours
     * whose class value is missing cast no vote.
     *
     * @param knnIdx neighbour indices into {@code train}, one row per test instance
     * @param train training fold that the indices refer to; its class must be nominal
     * @return predicted class index per test instance, stored as a double so it compares
     *     directly with {@code classValue()}
     * @throws IllegalArgumentException if the class attribute is not nominal
     */
    private static double[] getPredLabel(int[][] knnIdx, Instances train) {
        if (!train.classAttribute().isNominal()) {
            throw new IllegalArgumentException("k-NN voting requires a nominal class attribute");
        }
        int numClasses = train.numClasses();
        double[] y = new double[knnIdx.length];
        for (int i = 0; i < knnIdx.length; i++) {
            // Count how many of the top-k neighbours carry each label.
            int[] votes = new int[numClasses];
            for (int idx : knnIdx[i]) {
                if (!train.instance(idx).classIsMissing()) {
                    votes[(int) train.instance(idx).classValue()]++;
                }
            }
            // The most frequent label is the prediction.
            int best = 0;
            for (int c = 1; c < numClasses; c++) {
                if (votes[c] > votes[best]) {
                    best = c;
                }
            }
            y[i] = best;
        }

        return y;
    }


    /**
     * Reference implementation: accuracy of Weka's {@link IBk} with {@code k = KNN_K} on one
     * train/test split. It is used to cross-check the hand-written classifier.
     *
     * <p>If training or classification fails, the exception is printed. The method then returns
     * the accuracy over the instances classified so far, or NaN if none were.
     * NOTE(review): IBk's default distance function may normalise attributes, unlike
     * {@code getDistance}, so the two accuracies need not match exactly. Confirm against the
     * Weka version in use.
     *
     * @param train training split
     * @param test test split
     * @return fraction of correctly classified test instances
     */
    private static double KNNAcc(Instances train, Instances test) {
        IBk cls = new IBk();
        cls.setKNN(KNN_K);
        double[][] confMatrix = new double[train.numClasses()][train.numClasses()];

        try {
            cls.buildClassifier(train);

            for (int i = 0; i < test.numInstances(); ++i) {
                int pred = (int) cls.classifyInstance(test.instance(i));
                confMatrix[(int) test.instance(i).classValue()][pred]++;
            }
        } catch (Exception e) {
            // Best-effort: report the cause instead of hiding it, then score what was classified.
            System.out.println("Error: " + e);
        }

        double correct = 0;
        double total = 0;
        for (int i = 0; i < confMatrix.length; i++) {
            correct += confMatrix[i][i];
            for (int j = 0; j < confMatrix[i].length; j++) {
                total += confMatrix[i][j];
            }
        }

        return correct / total;
    }
}



评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值