数据准备

样本数据存放在 c://test/data/data1.xlsx 的第 0 个工作表中：第一行为表头，其后每行依次为整数坐标 x、整数坐标 y 和类别标签（本例中为 A、B、C、D）。
代码
package top.yuyufeng.learn;
import top.yuyufeng.learn.model.MyData;
import top.yuyufeng.learn.utils.ExcelReader;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
/**
 * KNN (k-nearest-neighbours) classification demo.
 *
 * <p>Loads labelled 2-D sample points from sheet {@code 0} of an Excel workbook, selects the
 * {@code k} samples closest to a new point by Euclidean distance, and classifies the new point by
 * majority vote among those neighbours.
 *
 * @author yuyufeng
 * @date 2019/4/23.
 */
public class App {

    /** Number of nearest neighbours that take part in the vote (the "k" in KNN). */
    private static final int K_NEIGHBOURS = 5;

    /** Workbook holding the labelled samples: a header row, then x, y, type per row. */
    private static final String SAMPLE_PATH = "c://test/data/data1.xlsx";

    /**
     * Classifies the point (8, 2) against the samples in {@link #SAMPLE_PATH}.
     *
     * <p>Prints every sample with its distance to the new point, then the vote count per type
     * among the {@link #K_NEIGHBOURS} nearest samples, and finally the winning type.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        System.out.println("获取样本数据");
        List<MyData> dataList = loadSamples(SAMPLE_PATH);
        System.out.println("获取样本数据结束");

        // The point to classify; its type is unknown.
        MyData myDataNew = new MyData(8, 2, null);

        // Diagnostic output: distance from every sample to the new point.
        for (MyData myData : dataList) {
            System.out.println(myData + " " + calDistance(myData, myDataNew));
        }

        // KNN takes the k *nearest* samples. The previous code instead kept every sample whose
        // distance was below k, which treats k as a radius and is not KNN.
        List<MyData> kDatas = findNearest(dataList, myDataNew, K_NEIGHBOURS);

        // Vote: count how many neighbours belong to each type.
        Map<String, Integer> keyCounts = new HashMap<>();
        for (MyData kData : kDatas) {
            keyCounts.merge(kData.getType(), 1, Integer::sum);
        }
        for (Map.Entry<String, Integer> entry : keyCounts.entrySet()) {
            System.out.println(entry.getKey() + " " + entry.getValue());
        }

        String predicted = majorityType(kDatas, keyCounts);
        if (predicted != null) {
            System.out.println("分类结果 " + predicted);
        }
    }

    /**
     * Reads labelled samples from sheet {@code 0} of the workbook at {@code path}.
     *
     * <p>Row 0 is a header and is skipped. Every following row is parsed by
     * {@link #toSample(String[])}.
     *
     * @param path location of the workbook
     * @return the parsed samples, in sheet order
     * @throws NumberFormatException if an x or y cell is not an integer
     */
    private static List<MyData> loadSamples(String path) {
        // NOTE(review): if ExcelReader keeps the workbook/stream open, it should be closed after
        // reading; confirm against its API.
        ExcelReader excelReader = new ExcelReader(path);
        List<String[]> rows = excelReader.getAllData(0);
        return rows.stream()
                .skip(1) // header row
                .map(App::toSample)
                .collect(Collectors.toList());
    }

    /**
     * Converts one spreadsheet row into a sample.
     *
     * @param row cells {@code [x, y, type]}; x and y must be integer text
     * @return a new sample populated from the row
     */
    private static MyData toSample(String[] row) {
        MyData sample = new MyData();
        sample.setX(Integer.parseInt(row[0]));
        sample.setY(Integer.parseInt(row[1]));
        sample.setType(row[2]);
        return sample;
    }

    /**
     * Returns the {@code k} samples closest to {@code target} by Euclidean distance, nearest
     * first.
     *
     * <p>If fewer than {@code k} samples exist, all of them are returned. Samples at equal distance
     * keep their original relative order because {@code ArrayList.sort} is stable.
     *
     * @param samples candidate neighbours; not modified
     * @param target  the point being classified
     * @param k       number of neighbours to return
     * @return at most {@code k} samples, sorted by increasing distance to {@code target}
     * @throws IllegalArgumentException if {@code k} is not positive
     */
    private static List<MyData> findNearest(List<MyData> samples, MyData target, int k) {
        if (k <= 0) {
            throw new IllegalArgumentException("k must be positive: " + k);
        }
        List<MyData> sorted = new ArrayList<>(samples);
        sorted.sort(Comparator.comparingDouble((MyData sample) -> calDistance(sample, target)));
        return sorted.subList(0, Math.min(k, sorted.size()));
    }

    /**
     * Picks the type with the most votes among {@code neighbours}.
     *
     * <p>Ties go to the type that appears first in {@code neighbours}. When the list is sorted
     * nearest first, that is the tied type with the closest member.
     *
     * @param neighbours the selected neighbours, nearest first
     * @param votes      vote count per type, covering every type present in {@code neighbours}
     * @return the winning type, or {@code null} if {@code neighbours} is empty
     */
    private static String majorityType(List<MyData> neighbours, Map<String, Integer> votes) {
        String best = null;
        int bestVotes = 0;
        for (MyData neighbour : neighbours) {
            int count = votes.get(neighbour.getType());
            // Strict '>' keeps the earlier (nearer) type on ties.
            if (count > bestVotes) {
                best = neighbour.getType();
                bestVotes = count;
            }
        }
        return best;
    }

    /**
     * Computes the Euclidean distance between two points in the x/y plane.
     *
     * @param source first point
     * @param target second point
     * @return {@code sqrt((x1 - x2)^2 + (y1 - y2)^2)}
     */
    public static double calDistance(MyData source, MyData target) {
        // Subtract in double so a large int coordinate difference cannot overflow.
        double dx = (double) source.getX() - target.getX();
        double dy = (double) source.getY() - target.getY();
        return Math.sqrt(dx * dx + dy * dy);
    }
}
计算结果（待分类的新节点为 (8, 2)；先输出每个样本及其到新节点的距离，最后输出各类别及其得票数）
获取样本数据
获取样本数据结束
MyData{x=2, y=2, type='C'} 6.0
MyData{x=2, y=3, type='C'} 6.082762530298219
MyData{x=1, y=10, type='A'} 10.63014581273465
MyData{x=5, y=7, type='A'} 5.830951894845301
MyData{x=1, y=5, type='C'} 7.615773105863909
MyData{x=8, y=8, type='B'} 6.0
MyData{x=5, y=10, type='A'} 8.54400374531753
MyData{x=10, y=5, type='D'} 3.605551275463989
MyData{x=9, y=3, type='D'} 1.4142135623730951
MyData{x=10, y=10, type='B'} 8.246211251235321
D 2