算法原理
k近邻算法非常简单,即在训练集中通过对样本的每个维度加权计算距离,找k个与测试样本最近的样本,统计最有可能的类别。
package knn;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
public class KNN {
//训练集
private List<VisibleData> training;
//距离计算方式
private BiFunction<VisibleData, VisibleData, Number> distance;
public KNN setTraining(List<VisibleData> c) {
this.training=c;
return this;
}
public String test(VisibleData test ,int k) {
//按距离比较
Comparator<VisibleData> com = new Comparator<VisibleData>() {
public int compare(VisibleData o1, VisibleData o2) {
Number d1 = distance.apply(o1, test);
Number d2 = distance.apply(o2, test);
return d1.doubleValue()>d2.doubleValue()?1
: (d1.doubleValue()<d2.doubleValue()?-1:0);
}
};
//流
//排序
//截取前k个元素
//分组统计次数
Map<String, Long> mapk = training.stream()
.sorted(com)
.limit(k)
.collect(Collectors.groupingBy(a->((VisibleData)a).getLabel(), Collectors.counting()));
System.out.println(mapk);
String label=null;
Long count=Long.MAX_VALUE;
for(String key:mapk.keySet()) {
Long tmp = mapk.get(key);
if(count>tmp) {
label=key;
count=tmp;
}
}
return label;
}
public KNN setDistance(BiFunction<VisibleData, VisibleData, Number> fun) {
this.distance=fun;
return this;
}
//随机数
public static ArrayList<Double> randomNumber(Double start,Double end,int num){
ArrayList<Double> list =new ArrayList<Double>();
while(list.size()<num) {
double tmp=Math.random()*(end-start)+start;
list.add(tmp);
}
return list;
}
}
//数据接口
interface VisibleData{
List<Number> getData();
String getLabel();
}
测试数据类
class Data implements VisibleData{
private double x;
private double y;
private String label;
public Data(double x, double y,String label) {
this.x=x;
this.y=y;
this.label=label;
}
public List<Number> getData() {
ArrayList<Number> list = new ArrayList<Number>();
list.add(x);
list.add(y);
return list;
}
public String getLabel() {
return label;
}
public String toString() {
return "("+x+","+y+")";
}
}
测试
public static void main(String[] args) {
List<VisibleData> training=new ArrayList<VisibleData>();
ArrayList<Double> ax = randomNumber(0.0,5.0,10);
ArrayList<Double> ay = randomNumber(0.0,5.0,10);
for(int i=0;i<10;i++) {
training.add(new Data(ax.get(i),ay.get(i),"A"));
}
ArrayList<Double> bx = randomNumber(8.0,10.0,10);
ArrayList<Double> by = randomNumber(8.0,10.0,10);
for(int i=0;i<10;i++) {
training.add(new Data(bx.get(i),by.get(i),"B"));
}
KNN knn=new KNN();
knn.setTraining(training);
knn.setDistance((a,b)->{
List<Double> dis=new ArrayList<Double>();
for(int i=0;i<a.getData().size();i++) {
dis.add(a.getData().get(i).doubleValue()-
b.getData().get(i).doubleValue());
}
return Math.sqrt(dis.stream().map(m->m*m).reduce(0.0,(x,y)->x+y).doubleValue());
});
System.out.println(knn.test(new Data(6.5,6.5,"test"),3));
}
结果分析:
由于(6.5,6.5)选取在两个类的交界处,多次测试将出现不同结果:
- 优点
- 算法简单,无需训练;
- 适合于多分类问题
- 缺点
- 易受样本容量影响,大样本吃小样本
- 复杂度高,测试一个样本就得遍历整个测试集