KNN算法的入门demo

数据准备

![样本数据（x、y、type）](https://img-blog.csdnimg.cn/20190423160426990.png)

代码

package top.yuyufeng.learn;

import top.yuyufeng.learn.model.MyData;
import top.yuyufeng.learn.utils.ExcelReader;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * Minimal k-nearest-neighbours (KNN) classification demo.
 *
 * <p>Loads labelled 2-D sample points from an Excel sheet, computes the Euclidean distance from
 * every sample to a query point, keeps the {@code k} closest samples and lets them vote on the
 * query point's type.
 *
 * @author yuyufeng
 * @date 2019/4/23.
 */
public class App {

    /** Workbook read when no path is passed on the command line. */
    private static final String DEFAULT_DATA_PATH = "c://test/data/data1.xlsx";

    /**
     * Runs the demo: load samples, find the k nearest ones to a fixed query point, vote.
     *
     * @param args optional; {@code args[0]} overrides the sample workbook path
     *             (defaults to {@code c://test/data/data1.xlsx})
     */
    public static void main(String[] args) {
        // k: how many nearest samples take part in the vote (a count, not a distance threshold)
        int k = 5;
        String dataPath = args.length > 0 ? args[0] : DEFAULT_DATA_PATH;

        System.out.println("获取样本数据");
        List<MyData> dataList = loadSamples(dataPath);
        System.out.println("获取样本数据结束");

        // The point to classify; its type is unknown, hence null.
        MyData myDataNew = new MyData(8, 2, null);

        // The k samples closest to the new point, nearest first.
        List<MyData> kDatas = nearest(dataList, myDataNew, k);

        // Majority vote among the k nearest samples.
        Map<String, Integer> keyCounts = countTypes(kDatas);
        String predicted = null;
        int bestCount = 0;
        for (Map.Entry<String, Integer> entry : keyCounts.entrySet()) {
            System.out.println(entry.getKey() + " " + entry.getValue());
            // Strict '>' keeps the first type reaching the maximum. countTypes inserts types in
            // nearest-first order, so a tie goes to the type owning the closest sample.
            if (entry.getValue() > bestCount) {
                bestCount = entry.getValue();
                predicted = entry.getKey();
            }
        }
        if (bestCount == 0) {
            System.out.println("没有可用于投票的样本");
        } else {
            System.out.println("预测类型 " + predicted);
        }
    }

    /**
     * Reads labelled samples from sheet 0 of the given workbook.
     *
     * <p>Row 0 is skipped (presumably the header row — confirm against the workbook). Each
     * remaining row is read as {@code x, y, type}; rows with fewer than three cells are skipped
     * rather than aborting the whole load.
     *
     * @param path workbook location passed to {@link ExcelReader}
     * @return the parsed samples, in sheet order
     * @throws NumberFormatException if an x or y cell is not an integer
     */
    private static List<MyData> loadSamples(String path) {
        ExcelReader excelReader = new ExcelReader(path);
        List<String[]> sheet0 = excelReader.getAllData(0);
        List<MyData> samples = new ArrayList<>();
        for (int i = 1; i < sheet0.size(); i++) {
            String[] row = sheet0.get(i);
            if (row == null || row.length < 3) {
                continue; // blank or incomplete row: nothing usable to parse
            }
            MyData sample = new MyData();
            sample.setX(Integer.parseInt(row[0]));
            sample.setY(Integer.parseInt(row[1]));
            sample.setType(row[2]);
            samples.add(sample);
        }
        return samples;
    }

    /**
     * Returns the {@code k} samples closest to {@code target}, nearest first.
     *
     * <p>Prints every sample together with its distance (in input order) as a side effect.
     * If {@code k} exceeds the number of samples, all samples are returned; a negative
     * {@code k} yields an empty list.
     *
     * @param samples candidate samples
     * @param target  the point being classified
     * @param k       maximum number of neighbours to return
     * @return up to {@code k} samples ordered by ascending distance
     */
    private static List<MyData> nearest(List<MyData> samples, MyData target, int k) {
        List<Neighbor> neighbors = new ArrayList<>(samples.size());
        for (MyData sample : samples) {
            double distance = calDistance(sample, target);
            System.out.println(sample + " " + distance);
            neighbors.add(new Neighbor(sample, distance));
        }
        // Collections.sort is stable, so samples at equal distance keep their input order.
        Collections.sort(neighbors, new Comparator<Neighbor>() {
            @Override
            public int compare(Neighbor a, Neighbor b) {
                return Double.compare(a.distance, b.distance);
            }
        });
        int limit = Math.min(Math.max(k, 0), neighbors.size());
        List<MyData> result = new ArrayList<>(limit);
        for (int i = 0; i < limit; i++) {
            result.add(neighbors.get(i).data);
        }
        return result;
    }

    /**
     * Counts how many of the given samples carry each type.
     *
     * @param neighbours samples to tally
     * @return type -> count, iterating in the order each type first appears in {@code neighbours}
     */
    private static Map<String, Integer> countTypes(List<MyData> neighbours) {
        Map<String, Integer> counts = new LinkedHashMap<>();
        for (MyData neighbour : neighbours) {
            Integer current = counts.get(neighbour.getType());
            counts.put(neighbour.getType(), current == null ? 1 : current + 1);
        }
        return counts;
    }

    /**
     * Euclidean distance between two points in the x/y plane.
     *
     * @param source first point
     * @param target second point
     * @return {@code sqrt((source.x - target.x)^2 + (source.y - target.y)^2)}
     */
    public static double calDistance(MyData source, MyData target) {
        double dx = source.getX() - target.getX();
        double dy = source.getY() - target.getY();
        return Math.sqrt(dx * dx + dy * dy);
    }

    /** A sample paired with its precomputed distance to the query point. */
    private static final class Neighbor {
        final MyData data;
        final double distance;

        Neighbor(MyData data, double distance) {
            this.data = data;
            this.distance = distance;
        }
    }

}

计算结果

获取样本数据
获取样本数据结束
MyData{x=2, y=2, type='C'} 6.0
MyData{x=2, y=3, type='C'} 6.082762530298219
MyData{x=1, y=10, type='A'} 10.63014581273465
MyData{x=5, y=7, type='A'} 5.830951894845301
MyData{x=1, y=5, type='C'} 7.615773105863909
MyData{x=8, y=8, type='B'} 6.0
MyData{x=5, y=10, type='A'} 8.54400374531753
MyData{x=10, y=5, type='D'} 3.605551275463989
MyData{x=9, y=3, type='D'} 1.4142135623730951
MyData{x=10, y=10, type='B'} 8.246211251235321
D 2
发布了154 篇原创文章 · 获赞 142 · 访问量 35万+
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 编程工作室 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览