最近邻分类算法思想
KNN算法的思想总结一下:就是在训练集中数据和标签已知的情况下,输入测试数据,将测试数据的特征与训练集中对应的特征进行相互比较,找到训练集中与之最为相似的前K个数据,则该测试数据对应的类别就是K个数据中出现次数最多的那个分类,其算法的描述为:
1)计算测试数据与各个训练数据之间的距离;
2)按照距离的递增关系进行排序;
3)选取距离最小的K个点;
4)确定前K个点所在类别的出现频率;
5)返回前K个点中出现频率最高的类别作为测试数据的预测分类。
Java代码实现
KNN.java代码
public class KNN {
public static void main(String[] args) {
// 一、输入所有已知点
List<Point>dataList = creatDataSet();
// 二、输入未知点
Point x = new Point(5, 1.2, 1.2);
// 三、计算所有已知点到未知点的欧式距离,并根据距离对所有已知点排序
CompareClass compare = new CompareClass();
Set<Distance> distanceSet = new TreeSet<Distance>(compare);
for (Pointpoint : dataList) {
distanceSet.add(new Distance(point.getId(), x.getId(), oudistance(point,
x)));
}
// 四、选取最近的k个点
double k = 5;
/**
* 五、计算k个点所在分类出现的频率
*/
// 1、计算每个分类所包含的点的个数
List<Distance> distanceList= new ArrayList<Distance>(distanceSet);
Map<String, Integer> map = getNumberOfType(distanceList, dataList, k);
// 2、计算频率
Map<String, Double> p = computeP(map, k);
x.setType(maxP(p));
System.out.println("未知点的类型为:"+x.getType());
}
// 欧式距离计算
public static double oudistance(Point point1, Pointpoint2) {
double temp = Math.pow(point1.getX() - point2.getX(), 2)
+ Math.pow(point1.getY() - point2.getY(), 2);
return Math.sqrt(temp);
}
// 找出最大频率
public static String maxP(Map<String,Double> map) {
String key = null;
double value = 0.0;
for (Map.Entry<String, Double> entry : map.entrySet()) {
if (entry.getValue() > value) {
key = entry.getKey();
value = entry.getValue();
}
}
return key;
}
// 计算频率
public static Map<String,Double> computeP(Map<String, Integer> map,
double k) {
Map<String, Double> p = new HashMap<String, Double>();
for (Map.Entry<String, Integer> entry : map.entrySet()) {
p.put(entry.getKey(), entry.getValue() / k);
}
return p;
}
// 计算每个分类包含的点的个数
public static Map<String,Integer> getNumberOfType(
List<Distance> listDistance, List<Point> listPoint, double k) {
Map<String, Integer> map = new HashMap<String, Integer>();
int i = 0;
System.out.println("选取的k个点,由近及远依次为:");
for (Distance distance : listDistance) {
System.out.println("id为" + distance.getId() + ",距离为:"
+ distance.getDisatance());
long id = distance.getId();
// 通过id找到所属类型,并存储到HashMap中
for (Point point : listPoint) {
if (point.getId() == id) {
if (map.get(point.getType()) != null)
map.put(point.getType(), map.get(point.getType()) + 1);
else {
map.put(point.getType(), 1);
}
}
}
i++;
if (i >= k)
break;
}
return map;
}
public static ArrayList<Point> creatDataSet(){
Point point1 = new Point(1, 1.0, 1.1, "A");
Point point2 = new Point(2, 1.0, 1.0, "A");
Point point3 = new Point(3, 1.0, 1.2, "A");
Point point4 = new Point(4, 0, 0, "B");
Point point5 = new Point(5, 0, 0.1, "B");
Point point6 = new Point(6, 0, 0.2, "B");
ArrayList<Point>dataList = new ArrayList<Point>();
dataList.add(point1);
dataList.add(point2);
dataList.add(point3);
dataList.add(point4);
dataList.add(point5);
dataList.add(point6);
return dataList;
}
}
KNN类中涉及到的Point类、Distance类以及比较器CompareClass类如下:
Point类
/**
 * A point in the plane, identified by a numeric id and optionally carrying a class label.
 */
public class Point {
    private long id;
    private double x;
    private double y;
    // Class label; null until assigned for a point created without one.
    private String type;

    /**
     * Creates a point without a class label ({@code type} is {@code null}).
     *
     * @param id identifier of the point
     * @param x  x coordinate
     * @param y  y coordinate
     */
    public Point(long id, double x, double y) {
        this(id, x, y, null);
    }

    /**
     * Creates a point with the given class label.
     *
     * @param id   identifier of the point
     * @param x    x coordinate
     * @param y    y coordinate
     * @param type class label
     */
    public Point(long id, double x, double y, String type) {
        this.id = id;
        this.x = x;
        this.y = y;
        this.type = type;
    }

    public long getId() {
        return id;
    }

    public void setId(long id) {
        this.id = id;
    }

    public double getX() {
        return x;
    }

    public void setX(double x) {
        this.x = x;
    }

    public double getY() {
        return y;
    }

    public void setY(double y) {
        this.y = y;
    }

    public String getType() {
        return type;
    }

    public void setType(String type) {
        this.type = type;
    }
}
Distance类
/**
 * Distance between one known (labelled) point and the unknown point being classified.
 */
public class Distance {
    // Id of the known point.
    private long id;
    // Id of the unknown point.
    private long nid;
    // Distance between the two points.
    private double distance;

    /**
     * @param id         id of the known point
     * @param nid        id of the unknown point
     * @param disatance  distance between the two points
     */
    public Distance(long id, long nid, double disatance) {
        this.id = id;
        this.nid = nid;
        this.distance = disatance;
    }

    public long getId() {
        return id;
    }

    public void setId(long id) {
        this.id = id;
    }

    public long getNid() {
        return nid;
    }

    public void setNid(long nid) {
        this.nid = nid;
    }

    /** Returns the distance; the accessor's spelling is kept for existing callers. */
    public double getDisatance() {
        return distance;
    }

    /** Sets the distance; the accessor's spelling is kept for existing callers. */
    public void setDisatance(double disatance) {
        this.distance = disatance;
    }
}
比较器CompareClass类
1 2 3 4 5 6 7 8 9 | import java.util.Comparator; //比较器类 public class CompareClass implements Comparator<Distance>{
public int compare(Distance d1, Distance d2) { return d1.getDisatance()>d2.getDisatance()?20 : -1; }
} |
其中,将Map<String,Double> typeAndDistance按照distance进行排序,也就是按照Map的value进行排序,也可以采用如下方法实现:
1. public class Testing {
2.
3. public static void main(String[] args) {
4.
5. HashMap<String,Double> map = new HashMap<String,Double>();
6. ValueComparator bvc = new ValueComparator(map);
7. TreeMap<String,Double> sorted_map = new TreeMap<String,Double>(bvc);
8.
9. map.put("A",99.5);
10. map.put("B",67.4);
11. map.put("C",67.4);
12. map.put("D",67.3);
13.
14. System.out.println("unsorted map: "+map);
15.
16. sorted_map.putAll(map);
17.
18. System.out.println("results: "+sorted_map);
19. }
20. }
21.
/**
 * Orders keys by their value in {@code base}, largest value first.
 *
 * <p>Keys with equal values are ordered by the key itself, so {@code compare} returns 0 only
 * for equal keys. A {@code TreeMap} built with this comparator therefore keeps distinct keys
 * that share a value, and it still supports {@code get}/{@code containsKey}. Both operations
 * fail with a comparator that never returns 0.
 *
 * <p>Every key compared must be present in {@code base}; a missing key's {@code null} value
 * causes a {@link NullPointerException} when unboxed.
 */
class ValueComparator implements Comparator<String> {

    Map<String, Double> base;

    /**
     * @param base map whose values define the ordering; it is read, not copied, so later
     *             changes to it affect comparisons
     */
    public ValueComparator(Map<String, Double> base) {
        this.base = base;
    }

    @Override
    public int compare(String a, String b) {
        // Arguments swapped on purpose: larger values sort first.
        int byValue = Double.compare(base.get(b), base.get(a));
        if (byValue != 0) {
            return byValue;
        }
        // Tie-break on the key, keeping the order consistent with equals.
        return a.compareTo(b);
    }
}