源代码:http://t.csdnimg.cn/fYpFf
-
背景介绍
k-近邻是在训练集中选取离输入的数据点最近的k个邻居,根据这个k个邻居中出现次数最多的类别,作为该数据点的类别。
-
实验内容
实验原理:
K-nearst neighbors是一种基本的机器学习算法,所谓k近邻,就是k个最近的邻居的意思,说的是每个样本都可以用它最接近的k个邻居来代表。比如:判断一个人的人品,只需要观察与他来往最密切的几个人的人品好坏就可以得出,即“近朱者赤,近墨者黑”;KNN算法既可以应用于分类应用中,也可以应用在回归应用中。
KNN算法的思想:计算每个训练数据到待分类元组的距离,取和待分类元组最近的k个训练数据,k个数据中哪个类别的训练数据占多数,则待分类元组就属于哪个类别。
距离度量的方式有三种:欧式距离、曼哈顿距离、闵可夫斯基距离。
流程图:
图2-1:knn算法流程图
算法:
图2-2:knn算法伪代码图
步骤:
(1)准备数据,对数据进行预处理;
(2)计算测试样本点(也就是待分类点)到其他每个样本点的距离;
(3)对每个距离进行排序,然后选择出距离最小的个点;
(4)对个点所属的类别进行比较,根据多数表决原则,将测试样本点归入在个点中占比最高的那一类。
关键源代码:
public static List<KNNmodel> calculate(List<KNNmodel> lists,KNNmodel newstudent,int k){
for (KNNmodel m:lists){
m.distance=Math.abs(m.height-newstudent.height);
}
Collections.sort(lists, Comparator.comparing(KNNmodel::getDistance)); System.out.println("按距离排序后:");
for (KNNmodel list : lists) {
System.out.println(list);}
List<KNNmodel> l=new ArrayList<>(k);
for (int i=0;i<k;i++){
l.add(lists.get(i));}
return l;
}
-
实验结果与分析
图2-3:距离最近的5个数据和结果图
只考虑高度一个维度,因此两个对象的距离计算直接使用了它们的差的绝对值得到。 如果在多维情况下,应该使用合适的距离公式,如欧式距离来计算和比较。在与待分类的元组最近的前五个数据中,属于中等的数据最多,按knn算法的思想,该待分类元组也属于中等类别。
-
小结与心得体会
算法简单好用,容易理解,精度高,理论成熟,既可以用来做分类也可以用来做回归,可用于数值型数据和离散型数据;但计算复杂性高,空间复杂性高,k值的选择对KNN算法的结果产生重大影响:
选择较小的k值,就相当于用较小的区域中的训练实例进行预测,训练误差会减小,只有与输入实例较近或相似的训练实例才会对预测结果起作用,与此同时带来的问题是K值的减小就意味着整体模型变得复杂,容易发生过拟合。
选择较大的k值,就相当于用较大区域中的训练实例进行预测,其优点是可以减少泛化误差,但缺点是训练误差会增大。这时候,与输入实例较远(不相似的)训练实例也会对预测器作用,使预测发生错误,且K值的增大就意味着整体的模型变得简单。
一个极端的情况是k等于样本数m,此时完全没有分类,此时无论输入实例是什么,都只是简单的预测它属于在训练实例中最多的类,模型过于简单。
既然要找到k个最近的邻居来做预测,那么只需要计算预测样本和所有训练集中的样本的距离,然后计算出最小的k个距离即可,接着多数表决,很容易做出预测。这个方法的确简单直接,在样本量少,样本特征少的时候有效。这个方法我们一般称之为蛮力实现。比较适合于少量样本的简单模型的时候用。
源代码:
import java.util.*;
public class Knn {
/**
* 数据模型
*/
public static class KNNmodel{
public String name;
public int xuhao;
public double height;
public String type;
public double distance;
public double getDistance(){
return distance;
}
public KNNmodel(int xuhao,String name,double height,String type){
this.height=height;
this.name=name;
this.xuhao=xuhao;
this.type=type;
}
public String toString() {
return "KNNmodel{" +
" xuhao='" + xuhao + '\'' +
",name='" + name + '\'' +
", height=" + height +
", type='" + type + '\'' +
", distance=" + distance +
'}';
}
}
/**
* 计算距离并排序,取前K个数据
* @param lists
* @param newstudent
* @param k
* @return
*/
public static List<KNNmodel> calculate(List<KNNmodel> lists,KNNmodel newstudent,int k){
for (KNNmodel m:lists){
m.distance=Math.abs(m.height-newstudent.height);
}
Collections.sort(lists, Comparator.comparing(KNNmodel::getDistance));//按照每个元素的距离(即 distance 属性)进行升序排序
System.out.println("按距离排序后:");
for (KNNmodel list : lists) {
System.out.println(list);
}
List<KNNmodel> l=new ArrayList<>(k);
for (int i=0;i<k;i++){
l.add(lists.get(i));
}
return l;
}
//根据传入的列表 lists 中的数据统计每个类型出现的次数,并返回出现次数最多的类型
public static String findtype(List<KNNmodel> lists){
Map<String,Integer> map=new HashMap<>();
//通过遍历 lists 列表中的每个 KNNmodel 对象来更新 map 中的数据。
for (KNNmodel m:lists){
int sum= map.get(m.type)==null?1:map.get(m.type)+1;
map.put(m.type,sum);
}
System.out.println(map.toString());
List<Map.Entry<String,Integer>> list=new ArrayList<>(map.entrySet());
Collections.sort(list,Comparator.comparing(Map.Entry::getValue));//按照每个元素的值(即出现次数)进行升序排序
return list.get(list.size()-1).getKey();// 获取排序后的最后一个元素(出现次数最多的类型)
}
public static void main(String[] args) {
List<KNNmodel> lists=new ArrayList<KNNmodel>();
lists.add(new KNNmodel(1,"李丽",1.5,"矮"));
lists.add(new KNNmodel(2,"吉米",1.92,"高"));
lists.add(new KNNmodel(3,"马大华",1.7,"中等"));
lists.add(new KNNmodel(4,"王晓华",1.73,"中等"));
lists.add(new KNNmodel(5,"刘敏",1.6,"矮"));
lists.add(new KNNmodel(6,"张强",1.75,"中等"));
lists.add(new KNNmodel(7,"李秦",1.6,"矮"));
lists.add(new KNNmodel(8,"王壮",1.9,"高"));
lists.add(new KNNmodel(9,"刘冰",1.68,"中等"));
lists.add(new KNNmodel(10,"张喆",1.78,"中等"));
lists.add(new KNNmodel(11,"杨毅",1.70,"中等"));
lists.add(new KNNmodel(12,"徐田",1.68,"中等"));
lists.add(new KNNmodel(13,"高杰",1.65,"矮"));
lists.add(new KNNmodel(14,"张晓",1.78,"中等"));
KNNmodel newstudent=new KNNmodel(15,"易昌",1.70,"null");
int k=5;
List<KNNmodel> mindistance = calculate(lists, newstudent, k);
System.out.println("取前"+k+"个元组:");
for (KNNmodel knNmodel : mindistance) {
System.out.println(knNmodel);
}
String type = findtype(mindistance);
System.out.println(type);
}
}