knn算法

算法原理

k近邻算法非常简单,即在训练集中通过对样本的每个维度加权计算距离,找k个与测试样本最近的样本,统计最有可能的类别。

package knn;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

public class KNN {
	//训练集
	private List<VisibleData> training;
	//距离计算方式
	private BiFunction<VisibleData, VisibleData, Number> distance;
	public KNN setTraining(List<VisibleData> c) {
		this.training=c;
		return this;
	}
	public  String test(VisibleData test ,int k) {
		//按距离比较
		Comparator<VisibleData> com = new Comparator<VisibleData>() {
			public int compare(VisibleData o1, VisibleData o2) {
				Number d1 = distance.apply(o1, test);
				Number d2 = distance.apply(o2, test);
				return d1.doubleValue()>d2.doubleValue()?1
						: (d1.doubleValue()<d2.doubleValue()?-1:0);
			}
		};
		
		//流
		//排序
		//截取前k个元素
		//分组统计次数
		Map<String, Long> mapk = training.stream()
		.sorted(com)
		.limit(k)
		.collect(Collectors.groupingBy(a->((VisibleData)a).getLabel(), Collectors.counting()));
		
		System.out.println(mapk);
		
		String label=null;
		Long count=Long.MAX_VALUE;
		for(String key:mapk.keySet()) {
			Long tmp = mapk.get(key);
			if(count>tmp) {
				label=key;
				count=tmp;
			}
		}
		return label;
	}
	public  KNN setDistance(BiFunction<VisibleData, VisibleData, Number> fun) {
		this.distance=fun;
		return this;
	}
	//随机数
	public static ArrayList<Double> randomNumber(Double start,Double end,int num){
		ArrayList<Double> list =new ArrayList<Double>();
		while(list.size()<num) {
			double tmp=Math.random()*(end-start)+start;
			list.add(tmp);
		}
		return list;
	}
}
//数据接口
interface VisibleData{
	List<Number> getData();
	String getLabel();
}
测试数据类
class Data implements VisibleData{
	private double x;
	private double y;
	private String label;
	public Data(double x, double y,String label) {
		this.x=x;
		this.y=y;
		this.label=label;
	}
	public List<Number> getData() {
		ArrayList<Number> list = new ArrayList<Number>();
		list.add(x);
		list.add(y);
		return list;
	}
	public String getLabel() {
		return label;
	}
	public String toString() {
		return "("+x+","+y+")";
	}
	
}
测试
public static void main(String[] args) {
		List<VisibleData> training=new ArrayList<VisibleData>();
		ArrayList<Double> ax = randomNumber(0.0,5.0,10);
		ArrayList<Double> ay = randomNumber(0.0,5.0,10);
		for(int i=0;i<10;i++) {
			training.add(new Data(ax.get(i),ay.get(i),"A"));
		}
		ArrayList<Double> bx = randomNumber(8.0,10.0,10);
		ArrayList<Double> by = randomNumber(8.0,10.0,10);
		for(int i=0;i<10;i++) {
			training.add(new Data(bx.get(i),by.get(i),"B"));
		}
		KNN knn=new KNN();
		knn.setTraining(training);
		knn.setDistance((a,b)->{
			List<Double> dis=new ArrayList<Double>();
			for(int i=0;i<a.getData().size();i++) {
				dis.add(a.getData().get(i).doubleValue()-
						b.getData().get(i).doubleValue());
			}
			return Math.sqrt(dis.stream().map(m->m*m).reduce(0.0,(x,y)->x+y).doubleValue());
		});
		System.out.println(knn.test(new Data(6.5,6.5,"test"),3));
	}

结果分析:
由于(6.5,6.5)选取在两个类的交界处,多次测试将出现不同结果:

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

  • 优点
    • 算法简单,无需训练;
    • 适合于多分类问题
  • 缺点
    • 易受样本容量影响,大样本吃小样本
    • 复杂度高,测试一个样本就得遍历整个测试集
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值