k近邻算法的Java实现

k近邻算法是机器学习算法中最简单的算法之一,工作原理是:存在一个样本数据集合,即训练样本集,并且样本集中的每个数据都存在标签,即我们知道样本集中每一数据和所属分类的对应关系。输入没有标签的新数据之后,将新数据的每个特征和样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似数据的分类标签作为新数据的标签。一般来说,我们只选取样本数据中前k个最相似的数据。

Java实现:

KNNData.java

package KNN;

public class KNNData implements Comparable<KNNData>{
    double c1;
    double c2;
    double c3;
    double distance;
    String type;
    
    public KNNData(double c1, double c2, double c3, String type) {
        this.c1 = c1;
        this.c2 = c2;
        this.c3 = c3;
        this.type = type;
    }
    
    @Override
    public int compareTo(KNNData arg0) {
        return Double.valueOf(this.distance).compareTo(Double.valueOf(arg0.distance));
    }    
}

KNN.java

package KNN;

import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class KNN {
    
    //训练集
    private List<KNNData> KNNDS = null;
    
    public KNN(List<KNNData> KNNDS) {
        this.KNNDS = KNNDS;
    }
    
    //欧式距离
    private static double disCal(KNNData i, KNNData td) {
        return Math.sqrt((i.c1 - td.c1)*(i.c1 - td.c1)+(i.c2 - td.c2)*(i.c2 - td.c2)+
                (i.c3 - td.c3)*(i.c3 - td.c3));
    }
    
    private static String getMaxValueKey(int k, List<KNNData> ts){
        //只保留前k个元素
        
        while(ts.size() != k) {
            ts.remove(k);
        }
                
        String sKey;
        //保存key以及出现次数
        HashMap<String,Integer> keySet = new HashMap<String,Integer>();
        keySet.put(ts.get(0).type,1);
        for (int x = 1; x < ts.size(); x++) {
            sKey = ts.get(x).type;
            if (keySet.containsKey(sKey)) {
                keySet.put(sKey, keySet.get(sKey)+1);
            } else {
                keySet.put(sKey, 1);
            }
        }
        Set<Map.Entry<String,Integer>> set = keySet.entrySet();
        Iterator<Map.Entry<String,Integer>> iter = set.iterator(); 
        
        int mValue = 0;
        String mType = "";
        while (iter.hasNext()){
            Map.Entry<String,Integer> map = iter.next();
            if (mValue < map.getValue()) {
                mType = map.getKey();
                mValue = map.getValue();
            }
        }
        
        return mType;
    }
    
    public static String knnCal(int k, KNNData i, List<KNNData> ts) {
        //保存距离
        for (KNNData td : ts) {
            td.distance = disCal(i, td);
        }
        Collections.sort(ts);    
        return getMaxValueKey(k, ts);
    }
}

KNNTest.java

package KNN;

import java.util.ArrayList;
import java.util.List;

public class KNNTest {

    public static void main(String[] args) {
        
        List<KNNData> kd = new ArrayList<KNNData>();
        //训练集
        kd.add(new KNNData(1.2,1.1,0.1,"A"));
        kd.add(new KNNData(1.2,1.1,0.1,"A"));
        kd.add(new KNNData(7,1.5,0.1,"B"));
        kd.add(new KNNData(6,1.2,0.1,"B"));
        kd.add(new KNNData(2,2.6,0.1,"C"));
        kd.add(new KNNData(2,2.6,0.1,"C"));
        kd.add(new KNNData(2,2.6,0.1,"C"));
        kd.add(new KNNData(100,1.1,0.1,"D"));

        System.out.println(KNN.knnCal(3, new KNNData(1.1,1.1,0.1,"N/A"), kd));
    }
}

 

转载于:https://www.cnblogs.com/finalboss1987/p/5237441.html

  • 3
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值