java - KNN算法

由于看网上的java有点多,自己写了一份,本人也是初学者,有错误请提出,大家一起学习。

 1 import java.io.BufferedReader;
 2 import java.io.File;
 3 import java.io.FileNotFoundException;
 4 import java.io.FileReader;
 5 import java.io.IOException;
 6 import java.util.*;
 7 
 8 
 9 public class Index {
10     public static void main(String[] args){
11         List<List<Double>> Filedatas = new ArrayList<List<Double>>();
12         List<List<Double>> Testdatas = new ArrayList<List<Double>>();
13         
14         readFile(Filedatas,Testdatas);
15         KNN knn = new KNN();
16         
17         for(int i=0;i<Filedatas.size();i++){
18             String s = knn.comdistance(3,Filedatas,Testdatas.get(i));
19             print(s,Testdatas.get(i));
20         }
21     }
22     //第4步、打印出结果
23     private static void print(String s,List<Double> testdata) {
24         System.out.print("测试数据:");
25         for(int i=0;i<testdata.size();i++){
26             System.out.print(testdata.get(i) + " ");
27         }
28         int label = Math.round(Float.parseFloat(s));
29         System.out.println("所属类别:" + label);
30     }
31 
32     //第1.1步、读取文件
33     private static void readFile(List<List<Double>> Filedatas, List<List<Double>> Testdatas) {
34         try {
35             BufferedReader bfd = new BufferedReader(new FileReader(new File("D://a.txt")));
36             Filedatas = read(bfd,Filedatas);
37             BufferedReader bft = new BufferedReader(new FileReader(new File("D://b.txt")));
38             Testdatas = read(bft,Testdatas);
39         } catch (FileNotFoundException e) {
40             e.printStackTrace();
41         }        
42     }
43 
44     //第1.2步、读取文件
45     private static List<List<Double>> read(BufferedReader bf, List<List<Double>> datas) {
46         try {
47             String str = bf.readLine();
48             while(str != null){
49                 List<Double> d = new ArrayList<Double>();
50                 String[] string = str.split(" "); 
51                 for (String s : string) {
52                     d.add(Double.parseDouble(s));
53                 }
54                 datas.add(d);
55                 str = bf.readLine();
56             }
57         } catch (IOException e) {
58             e.printStackTrace();
59         }
60         return datas;
61     }
62 
63     
64 }
 1 import java.util.Comparator;
 2 import java.util.HashMap;
 3 import java.util.List;
 4 import java.util.Map;
 5 import java.util.PriorityQueue;
 6 
 7 public class KNN {
 8     
 9     public String comdistance(int k, List<List<Double>> filedatas,List<Double> testdata) {
10         //第2.1步、对加入queue队列的项进行距离的排序
11         PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k,new Comparator<KNNNode>() {    //优先级队列,按照distance的大小进行排列
12             @Override
13             public int compare(KNNNode o1, KNNNode o2) {
14                 if(o1.getDistance() >= o2.getDistance()){
15                     return -1;
16                 }
17                 else{
18                     return 1;
19                 }
20             }
21         });
22         //第2.2步、计算测试点与训练点的距离,并add进队列,挑出与测试点距离最近的K个点
23         for(int i=0;i<k;i++){
24             double distance = 0;
25             for(int j=0;j<filedatas.get(i).size()-1;j++){
26                 distance += (filedatas.get(i).get(j) - testdata.get(j)) * (filedatas.get(i).get(j) - testdata.get(j));
27             }
28             KNNNode node = new KNNNode(filedatas.get(i).get(filedatas.get(i).size()-1).toString(),distance);
29             pq.add(node);
30         }
31         for(int i=k;i<filedatas.size();i++){
32             double distance = 0;
33             for(int j=0;j<filedatas.get(i).size()-1;j++){
34                 distance += (filedatas.get(i).get(j) - testdata.get(j)) * (filedatas.get(i).get(j) - testdata.get(j));
35             }
36             KNNNode node = new KNNNode(filedatas.get(i).get(filedatas.get(i).size()-1).toString(),distance);
37             if(pq.peek().getDistance() >= distance ){
38                 pq.remove();
39                 pq.add(node);
40             }
41         }
42         String s = decide(pq);
43         return s;
44     }
45     //第3步、把选择好的最近的K个点的类别进行比较,多的即是测试点的类别
46     private String decide(PriorityQueue<KNNNode> pq) {
47         Map<String,Integer> m = new HashMap<String,Integer>();
48         for (KNNNode Node : pq) {
49             if(m.containsKey(Node.getC())){
50                 m.put(Node.getC(), m.get(Node.getC()) + 1);
51             }
52             else{
53                 m.put(Node.getC(), 1);
54             }
55         }
56         Object[] o = m.keySet().toArray();
57 
58         if(o.length == 1){
59             return o[0].toString();
60         }
61         else{
62             for(int i=0;i<o.length;i++){
63                 for(int j=i;j<o.length;j++){
64                     if(i != j){
65                         if(m.get(o[i]) > m.get(o[j])){
66                             return o[i].toString();
67                         }
68                         else{
69                             return o[j].toString();
70                         }
71                     }
72                 }
73             }
74         }
75         return null;
76     }
77 }
 1 public class KNNNode {
 2     
 3     private String c;
 4     private double distance;
 5     
 6     public KNNNode(String c, double distance) {
 7         super();
 8         this.c = c;
 9         this.distance = distance;
10     }
11     
12     public String getC() {
13         return c;
14     }
15     public double getDistance() {
16         return distance;
17     }
18     public void setC(String c) {
19         this.c = c;
20     }
21     public void setDistance(double distance) {
22         this.distance = distance;
23     }
24 }

训练数据:

1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0

测试数据:

1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5
1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8
1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2
1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5
1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5
1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5

转载于:https://www.cnblogs.com/wn19910213/p/3325477.html

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值