使用Java编程语言实现的KNN算法

1、数据实体类

训练数据TrainRecord和测试数据TestRecord均需要继承Record数据实体类。其中TrainRecord多了一个保存距离的distance成员变量,TestRecord多了一个保存预测类别的predictLabel成员变量。

基类:

public class Record {
    public double[] attributes;
    public int classLabel;
    Record(double[] attributes, int classLabel){
        this.attributes = new double[attributes.length];
        for (int i = 0; i < attributes.length; i++) {
            this.attributes[i] = attributes[i];
        }
        this.classLabel = classLabel;
    }
}

2、计算距离类

所有计算距离类都要实现Metric接口,实现getDistance获取距离函数。下面给出最普通的欧氏距离代码。

public class EuclideanDistance implements Metric {
    @Override
    public double getDistance(Record s, Record e) {
        assert s.attributes.length == e.attributes.length : "s and e are different types of records!";
        int numOfAttributes = s.attributes.length;
        double sum2 = 0;
        for(int i = 0; i < numOfAttributes; i ++){
            sum2 += Math.pow(s.attributes[i] - e.attributes[i], 2);
        }
        return Math.sqrt(sum2);
    }
}

3、找到K个最近邻

KNN算法核心的就是计算距离并返回最核心的K个近邻,首先定义一个保存K近邻的数组。

TrainRecord[] neighbors = new TrainRecord[K];

其次,计算测试实例与所有训练实例使用距离公式计算距离,使用neighbors数组记录最近的K个近邻。

        for(index = K; index < NumOfTrainingSet; index ++){
            trainingSet[index].distance = metric.getDistance(trainingSet[index], testRecord);
            //找出距离最远的实例将其替换掉
            int maxIndex = 0;
            for(int i = 1; i < K; i ++){
                if(neighbors[i].distance > neighbors[maxIndex].distance)
                    maxIndex = i;
            }
            if(neighbors[maxIndex].distance > trainingSet[index].distance)
                neighbors[maxIndex] = trainingSet[index];
        }

4、分类

在分类时,针对k个近邻,距离越远的权值越小。通过投票计算所有类别对应的值,返回最大值对应的类别。

        for(int index = 0;index < num; index ++) {
            TrainRecord temp = neighbors[index];
            int key = temp.classLabel;
            if (!map.containsKey(key))
                map.put(key, 1 / temp.distance);
            else {
                double value = map.get(key);
                value += 1 / temp.distance;
                map.put(key, value);
            }
        }
        double maxSimilarity = 0;
        int returnLabel = -1;
        Set<Integer> labelSet = map.keySet();
        Iterator<Integer> it = labelSet.iterator();
        while(it.hasNext()){
            int label = it.next();
            double value = map.get(label);
            if(value > maxSimilarity){
                maxSimilarity = value;
                returnLabel = label;
            }
        }

结语:计算公式可以进行改进,不单单使用欧氏距离。同时加权的权重可以进行调整,使其更精确。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值