Java实现KNN(Weka)

import java.util.Random;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.classifiers.lazy.IBk;


/**
 * A hand-rolled k-nearest-neighbours classifier evaluated with k-fold
 * cross-validation on a Weka {@link Instances} data set, plus a reference
 * implementation that delegates to Weka's {@link IBk}.
 *
 * <p>Distances are squared Euclidean over all non-class attributes; ties in
 * the majority vote are broken in favour of the lower class index.
 */
public class KNN {

    /** Number of nearest neighbours consulted per test instance. */
    private static final int DEFAULT_K = 3;

    /** Number of folds used by the cross-validation loop. */
    private static final int NUM_FOLDS = 10;

    public static void main(String[] args) throws Exception {
        // Load the ARFF file; by convention the class label is the last attribute.
        Instances data = DataSource.read("dataset/colon.arff");
        data.setClassIndex(data.numAttributes() - 1);
        // NOTE: unseeded Random, so results differ between runs; pass a fixed
        // seed (e.g. new Random(42)) if reproducibility is needed.
        data.randomize(new Random());

        double acc = KFoldFitness(data);
        System.out.println(acc);
    }

    /**
     * Estimates accuracy of the hand-rolled KNN via {@value #NUM_FOLDS}-fold
     * cross-validation using {@value #DEFAULT_K} neighbours.
     *
     * @param train full data set (class index must already be set)
     * @return mean per-fold accuracy in [0, 1]
     */
    private static double KFoldFitness(Instances train) {
        return KFoldFitness(train, DEFAULT_K);
    }

    /**
     * Same as {@link #KFoldFitness(Instances)} but with a caller-chosen
     * neighbour count, so k is no longer hard-coded.
     *
     * @param train full data set (class index must already be set)
     * @param knnK  number of neighbours to vote with (must be >= 1)
     * @return mean per-fold accuracy in [0, 1]
     */
    private static double KFoldFitness(Instances train, int knnK) {
        train.randomize(new Random());

        double acc = 0;
        for (int i = 0; i < NUM_FOLDS; i++) {
            Instances trainData = train.trainCV(NUM_FOLDS, i);
            Instances testData = train.testCV(NUM_FOLDS, i);
            double[][] dist = getDistance(trainData, testData);
            int[][] knnIdx = getKNNIndex(dist, knnK);
            double[] predictLabel = getPredLabel(knnIdx, trainData);
            int correct = 0;
            for (int j = 0; j < testData.numInstances(); j++) {
                if (predictLabel[j] == testData.instance(j).classValue()) {
                    correct++;
                }
            }
            acc += (double) correct / testData.numInstances();
        }

        return acc / NUM_FOLDS;
    }

    /**
     * Computes the squared Euclidean distance from every test instance to
     * every training instance.
     *
     * <p>The square root is deliberately omitted: it is monotone, so neighbour
     * ranking is unaffected and the work is saved.
     *
     * @param train training instances (distance targets)
     * @param test  test instances (distance sources)
     * @return matrix of size [test.numInstances()][train.numInstances()]
     */
    private static double[][] getDistance(Instances train, Instances test) {
        double[][] dist = new double[test.numInstances()][train.numInstances()];
        for (int i = 0; i < test.numInstances(); i++) {
            for (int j = 0; j < train.numInstances(); j++) {
                double diff = 0;
                // Skips the last attribute — assumes the class is the final
                // attribute, consistent with setClassIndex(...) in main.
                for (int k = 0; k < test.instance(i).numAttributes() - 1; k++) {
                    double d = test.instance(i).value(k) - train.instance(j).value(k);
                    diff += d * d;
                }
                dist[i][j] = diff;
            }
        }

        return dist;
    }

    /**
     * Selects, for each test instance, the indices of its {@code knnK}
     * closest training instances.
     *
     * <p>Unlike the original version, this works on a per-row copy so the
     * caller's distance matrix is left intact (the old code overwrote chosen
     * entries with +Infinity in place).
     *
     * @param dist distance matrix from {@link #getDistance}
     * @param knnK number of neighbours to select per row
     * @return matrix of size [dist.length][knnK] of training-instance indices,
     *         ordered closest first
     */
    private static int[][] getKNNIndex(double[][] dist, int knnK) {
        int[][] knnIdx = new int[dist.length][knnK];
        for (int i = 0; i < dist.length; i++) {
            double[] row = dist[i].clone();
            for (int j = 0; j < knnK; j++) {
                // Linear scan for the current minimum; knnK is small so the
                // O(knnK * n) cost is acceptable.
                int idx = 0;
                for (int k = 1; k < row.length; k++) {
                    if (row[k] < row[idx]) {
                        idx = k;
                    }
                }
                knnIdx[i][j] = idx;
                row[idx] = Double.POSITIVE_INFINITY; // exclude from next pass
            }
        }

        return knnIdx;
    }

    /**
     * Predicts a class index for each test instance by majority vote over its
     * selected neighbours' class values.
     *
     * <p>Bug fix: the original built an unused {@code labels} array via
     * {@code Double.parseDouble(classAttribute().value(i))}, which was dead
     * code and would throw {@link NumberFormatException} for any nominal class
     * label that is not numeric text (e.g. "positive"). Voting only ever used
     * the class <em>index</em>, so that array has been removed.
     *
     * @param knnIdx neighbour indices from {@link #getKNNIndex}
     * @param train  training instances the indices refer to
     * @return predicted class index (as double) per test instance; ties go to
     *         the lower class index
     */
    private static double[] getPredLabel(int[][] knnIdx, Instances train) {
        double[] y = new double[knnIdx.length];
        for (int i = 0; i < knnIdx.length; i++) {
            // Count votes per class index among this instance's neighbours.
            int[] votes = new int[train.numClasses()];
            for (int j = 0; j < knnIdx[i].length; j++) {
                votes[(int) train.instance(knnIdx[i][j]).classValue()]++;
            }

            // The most-voted class index is the prediction.
            int maxIdx = 0;
            for (int c = 1; c < votes.length; c++) {
                if (votes[c] > votes[maxIdx]) {
                    maxIdx = c;
                }
            }
            y[i] = maxIdx;
        }

        return y;
    }

    /**
     * Reference accuracy computed with Weka's built-in {@link IBk} classifier,
     * for cross-checking the hand-rolled implementation.
     *
     * <p>Bug fixes: the original swallowed all exceptions (printing "Error")
     * and then divided by a possibly-zero total, yielding NaN; failures are
     * now rethrown with their cause preserved, and an empty test set returns 0.
     *
     * @param train training instances
     * @param test  test instances
     * @return fraction of test instances classified correctly, in [0, 1]
     */
    private static double KNNAcc(Instances train, Instances test) {
        IBk cls = new IBk();
        cls.setKNN(DEFAULT_K);
        double[][] confMatrix = new double[train.numClasses()][train.numClasses()];

        try {
            cls.buildClassifier(train);

            for (int i = 0; i < test.numInstances(); i++) {
                int pred = (int) cls.classifyInstance(test.instance(i));
                confMatrix[(int) test.instance(i).classValue()][pred]++;
            }
        } catch (Exception e) {
            throw new RuntimeException("IBk cross-check failed", e);
        }

        double correct = 0;
        int total = 0;
        for (int i = 0; i < confMatrix.length; i++) {
            correct += confMatrix[i][i]; // diagonal = correct predictions
            for (int j = 0; j < confMatrix[i].length; j++) {
                total += confMatrix[i][j];
            }
        }

        return total == 0 ? 0 : correct / total;
    }
}



  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
KNN是一种机器学习算法,主要用于分类和回归。它的工作原理是将每个数据点分配到它最接近的k个邻居之一,然后利用这些邻居的标签来预测新数据点的标签。 以下是使用Java实现KNN算法的示例代码: ```java import java.util.*; public class KNN { // 使用欧几里得距离计算两个数据点之间的距离 public static double euclideanDistance(double[] x, double[] y) { double distance = 0; for (int i = 0; i < x.length; ++i) { distance += Math.pow(x[i] - y[i], 2); } return Math.sqrt(distance); } // 在给定的训练集中查找k个最近邻居 public static int[] nearestNeighbors(double[] x, double[][] data, int k) { double[] distances = new double[data.length]; // 计算x和数据集中每个点的距离 for (int i = 0; i < data.length; ++i) { distances[i] = euclideanDistance(x, data[i]); } // 找到k个最近邻居的索引 int[] neighbors = new int[k]; for (int i = 0; i < k; ++i) { int index = 0; double min = distances[0]; for (int j = 1; j < distances.length; ++j) { if (distances[j] < min) { index = j; min = distances[j]; } } neighbors[i] = index; distances[index] = Double.MAX_VALUE; } return neighbors; } // 对x进行分类 public static String classify(double[] x, double[][] data, String[] labels, int k) { // 找到k个最近邻居的索引 int[] neighbors = nearestNeighbors(x, data, k); // 统计每个类的数量 Map<String, Integer> counts = new HashMap<>(); for (int i = 0; i < neighbors.length; ++i) { String label = labels[neighbors[i]]; counts.put(label, counts.getOrDefault(label, 0) + 1); } // 找到数量最多的类 String result = null; int maxCount = -1; for (String label : counts.keySet()) { int count = counts.get(label); if (count > maxCount) { result = label; maxCount = count; } } return result; } public static void main(String[] args) { double[][] data = new double[][]{{1, 1}, {2, 2}, {3, 3}, {4, 4}, {5, 5}}; String[] labels = new String[]{"A", "A", "B", "B", "B"}; double[] x = new double[]{2.5, 2.5}; int k = 3; String result = classify(x, data, labels, k); System.out.println("分类结果:" + result); } } ``` 在这个示例中,我们使用欧几里得距离作为两个数据点之间的距离度量,然后使用nearestNeighbors方法找到最近的k个邻居,最后使用classify方法对新数据点进行分类。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值