数据挖掘:K最近邻(KNN)算法的java实现

1.急切学习与懒惰学习

急切学习:在给定训练元组之后、接收到测试元组之前就构造好泛化(即分类)模型。

属于急切学习的算法有:决策树、贝叶斯、基于规则的分类、后向传播分类、SVM和基于关联规则挖掘的分类等等。

 

懒惰学习:直至给定一个测试元组才开始构造泛化模型,也称为基于实例的学习法。

属于急切学习的算法有:KNN分类、基于案例的推理分类。

 

2.KNN的优缺点

优点:原理简单,实现起来比较方便。支持增量学习。能对超多边形的复杂决策空间建模。

缺点:计算开销大,需要有效的存储技术和并行硬件的支撑。

 

3.KNN算法原理

基于类比学习,通过比较训练元组和测试元组的相似度来学习。

将训练元组和测试元组看作是n维(若元组有n的属性)空间内的点,给定一条测试元组,搜索n维空间,找出与测试

元组最相近的k个点(即训练元组),最后取这k个点中的多数类作为测试元组的类别。

 

相近的度量方法:用空间内两个点的距离来度量。距离越大,表示两个点越不相似。

 

距离的选择:可采用欧几里得距离、曼哈顿距离或其它距离度量。多采用欧几里得距离,简单!

 

4.KNN算法中的细节处理

  • 数值属性规范化:将数值属性规范到0-1区间以便于计算,也可防止大数值型属性对分类的主导作用。

可选的方法有:v' = (v - vmin)/ (vmax - vmin),当然也可以采用其它的规范化方法

  • 比较的属性是分类类型而不是数值类型的:同则差为0,异则差为1.

有时候可以作更为精确的处理,比如黑色与白色的差肯定要大于灰色与白色的差。

  • 缺失值的处理:取最大的可能差,对于分类属性,如果属性A的一个或两个对应值丢失,则取差值为1;

如果A是数值属性,若两个比较的元组A属性值均缺失,则取差值为1,若只有一个缺失,另一个值为v,

则取差值为|1-v|和|0-v|中的最大值

  • 确定K的值:通过实验确定。进行若干次实验,取分类误差率最小的k值。
  • 对噪声数据或不相关属性的处理:对属性赋予相关性权重w,w越大说明属性对分类的影响越相关。对噪声数据可以将所在

的元组直接cut掉。

 

5.KNN算法流程

  • 准备数据,对数据进行预处理
  • 选用合适的数据结构存储训练数据和测试元组
  • 设定参数,如k
  • 维护一个大小为k的的按距离由大到小的优先级队列,用于存储最近邻训练元组
  • 随机从训练元组中选取k个元组作为初始的最近邻元组,分别计算测试元组到这k个元组的距离,将训练元组标号和距离存入优先级队列
  • 遍历训练元组集,计算当前训练元组与测试元组的距离,将所得距离L与优先级队列中的最大距离Lmax进行比较。若L>=Lmax,则舍弃该元组,遍历下一个元组。若L < Lmax,删除优先级队列中最大距离的元组,将当前训练元组存入优先级队列。
  • 遍历完毕,计算优先级队列中k个元组的多数类,并将其作为测试元组的类别。
  • 测试元组集测试完毕后计算误差率,继续设定不同的k值重新进行训练,最后取误差率最小的k值。

6.KNN算法的改进策略

  • 将存储的训练元组预先排序并安排在搜索树中(如何排序有待研究)
  • 并行实现
  • 部分距离计算,取n个属性的“子集”计算出部分距离,若超过设定的阈值则停止对当前元组作进一步计算。转向下一个元组。
  • 剪枝或精简:删除证明是“无用的”元组。

7.KNN算法java实现


本算法只适合学习使用,可以大致了解一下KNN算法的原理。

 

算法作了如下的假定与简化处理:

1.小规模数据集

2.假设所有数据及类别都是数值类型的

3.直接根据数据规模设定了k值

4.对原训练集进行测试

 

KNN实现代码如下:

  1. package KNN;  
  2. /** 
  3.  * KNN结点类,用来存储最近邻的k个元组相关的信息 
  4.  * @author Rowen 
  5.  * @qq 443773264 
  6.  * @mail luowen3405@163.com 
  7.  * @blog blog.csdn.net/luowen3405 
  8.  * @data 2011.03.25 
  9.  */  
  10. public class KNNNode {  
  11.     private int index; // 元组标号  
  12.     private double distance; // 与测试元组的距离  
  13.     private String c; // 所属类别  
  14.     public KNNNode(int index, double distance, String c) {  
  15.         super();  
  16.         this.index = index;  
  17.         this.distance = distance;  
  18.         this.c = c;  
  19.     }  
  20.       
  21.       
  22.     public int getIndex() {  
  23.         return index;  
  24.     }  
  25.     public void setIndex(int index) {  
  26.         this.index = index;  
  27.     }  
  28.     public double getDistance() {  
  29.         return distance;  
  30.     }  
  31.     public void setDistance(double distance) {  
  32.         this.distance = distance;  
  33.     }  
  34.     public String getC() {  
  35.         return c;  
  36.     }  
  37.     public void setC(String c) {  
  38.         this.c = c;  
  39.     }  
  40. }  
 

 

  1. package KNN;  
  2. import java.util.ArrayList;  
  3. import java.util.Comparator;  
  4. import java.util.HashMap;  
  5. import java.util.List;  
  6. import java.util.Map;  
  7. import java.util.PriorityQueue;  
  8.   
  9. /** 
  10.  * KNN算法主体类 
  11.  * @author Rowen 
  12.  * @qq 443773264 
  13.  * @mail luowen3405@163.com 
  14.  * @blog blog.csdn.net/luowen3405 
  15.  * @data 2011.03.25 
  16.  */  
  17. public class KNN {  
  18.     /** 
  19.      * 设置优先级队列的比较函数,距离越大,优先级越高 
  20.      */  
  21.     private Comparator<KNNNode> comparator = new Comparator<KNNNode>() {  
  22.         public int compare(KNNNode o1, KNNNode o2) {  
  23.             if (o1.getDistance() >= o2.getDistance()) {  
  24.                 return 1;  
  25.             } else {  
  26.                 return 0;  
  27.             }  
  28.         }  
  29.     };  
  30.     /** 
  31.      * 获取K个不同的随机数 
  32.      * @param k 随机数的个数 
  33.      * @param max 随机数最大的范围 
  34.      * @return 生成的随机数数组 
  35.      */  
  36.     public List<Integer> getRandKNum(int k, int max) {  
  37.         List<Integer> rand = new ArrayList<Integer>(k);  
  38.         for (int i = 0; i < k; i++) {  
  39.             int temp = (int) (Math.random() * max);  
  40.             if (!rand.contains(temp)) {  
  41.                 rand.add(temp);  
  42.             } else {  
  43.                 i--;  
  44.             }  
  45.         }  
  46.         return rand;  
  47.     }  
  48.     /** 
  49.      * 计算测试元组与训练元组之前的距离 
  50.      * @param d1 测试元组 
  51.      * @param d2 训练元组 
  52.      * @return 距离值 
  53.      */  
  54.     public double calDistance(List<Double> d1, List<Double> d2) {  
  55.         double distance = 0.00;  
  56.         for (int i = 0; i < d1.size(); i++) {  
  57.             distance += (d1.get(i) - d2.get(i)) * (d1.get(i) - d2.get(i));  
  58.         }  
  59.         return distance;  
  60.     }  
  61.     /** 
  62.      * 执行KNN算法,获取测试元组的类别 
  63.      * @param datas 训练数据集 
  64.      * @param testData 测试元组 
  65.      * @param k 设定的K值 
  66.      * @return 测试元组的类别 
  67.      */  
  68.     public String knn(List<List<Double>> datas, List<Double> testData, int k) {  
  69.         PriorityQueue<KNNNode> pq = new PriorityQueue<KNNNode>(k, comparator);  
  70.         List<Integer> randNum = getRandKNum(k, datas.size());  
  71.         for (int i = 0; i < k; i++) {  
  72.             int index = randNum.get(i);  
  73.             List<Double> currData = datas.get(index);  
  74.             String c = currData.get(currData.size() - 1).toString();  
  75.             KNNNode node = new KNNNode(index, calDistance(testData, currData), c);  
  76.             pq.add(node);  
  77.         }  
  78.         for (int i = 0; i < datas.size(); i++) {  
  79.             List<Double> t = datas.get(i);  
  80.             double distance = calDistance(testData, t);  
  81.             KNNNode top = pq.peek();  
  82.             if (top.getDistance() > distance) {  
  83.                 pq.remove();  
  84.                 pq.add(new KNNNode(i, distance, t.get(t.size() - 1).toString()));  
  85.             }  
  86.         }  
  87.           
  88.         return getMostClass(pq);  
  89.     }  
  90.     /** 
  91.      * 获取所得到的k个最近邻元组的多数类 
  92.      * @param pq 存储k个最近近邻元组的优先级队列 
  93.      * @return 多数类的名称 
  94.      */  
  95.     private String getMostClass(PriorityQueue<KNNNode> pq) {  
  96.         Map<String, Integer> classCount = new HashMap<String, Integer>();  
  97.         for (int i = 0; i < pq.size(); i++) {  
  98.             KNNNode node = pq.remove();  
  99.             String c = node.getC();  
  100.             if (classCount.containsKey(c)) {  
  101.                 classCount.put(c, classCount.get(c) + 1);  
  102.             } else {  
  103.                 classCount.put(c, 1);  
  104.             }  
  105.         }  
  106.         int maxIndex = -1;  
  107.         int maxCount = 0;  
  108.         Object[] classes = classCount.keySet().toArray();  
  109.         for (int i = 0; i < classes.length; i++) {  
  110.             if (classCount.get(classes[i]) > maxCount) {  
  111.                 maxIndex = i;  
  112.                 maxCount = classCount.get(classes[i]);  
  113.             }  
  114.         }  
  115.         return classes[maxIndex].toString();  
  116.     }  
  117. }  
 

 

  1. package KNN;  
  2. import java.io.BufferedReader;  
  3. import java.io.File;  
  4. import java.io.FileReader;  
  5. import java.util.ArrayList;  
  6. import java.util.List;  
  7. /** 
  8.  * KNN算法测试类 
  9.  * @author Rowen 
  10.  * @qq 443773264 
  11.  * @mail luowen3405@163.com 
  12.  * @blog blog.csdn.net/luowen3405 
  13.  * @data 2011.03.25 
  14.  */  
  15. public class TestKNN {  
  16.       
  17.     /** 
  18.      * 从数据文件中读取数据 
  19.      * @param datas 存储数据的集合对象 
  20.      * @param path 数据文件的路径 
  21.      */  
  22.     public void read(List<List<Double>> datas, String path){  
  23.         try {  
  24.             BufferedReader br = new BufferedReader(new FileReader(new File(path)));  
  25.             String data = br.readLine();  
  26.             List<Double> l = null;  
  27.             while (data != null) {  
  28.                 String t[] = data.split(" ");  
  29.                 l = new ArrayList<Double>();  
  30.                 for (int i = 0; i < t.length; i++) {  
  31.                     l.add(Double.parseDouble(t[i]));  
  32.                 }  
  33.                 datas.add(l);  
  34.                 data = br.readLine();  
  35.             }  
  36.         } catch (Exception e) {  
  37.             e.printStackTrace();  
  38.         }  
  39.     }  
  40.       
  41.     /** 
  42.      * 程序执行入口 
  43.      * @param args 
  44.      */  
  45.     public static void main(String[] args) {  
  46.         TestKNN t = new TestKNN();  
  47.         String datafile = new File("").getAbsolutePath() + File.separator + "datafile";  
  48.         String testfile = new File("").getAbsolutePath() + File.separator + "testfile";  
  49.         try {  
  50.             List<List<Double>> datas = new ArrayList<List<Double>>();  
  51.             List<List<Double>> testDatas = new ArrayList<List<Double>>();  
  52.             t.read(datas, datafile);  
  53.             t.read(testDatas, testfile);  
  54.             KNN knn = new KNN();  
  55.             for (int i = 0; i < testDatas.size(); i++) {  
  56.                 List<Double> test = testDatas.get(i);  
  57.                 System.out.print("测试元组: ");  
  58.                 for (int j = 0; j < test.size(); j++) {  
  59.                     System.out.print(test.get(j) + " ");  
  60.                 }  
  61.                 System.out.print("类别为: ");  
  62.                 System.out.println(Math.round(Float.parseFloat((knn.knn(datas, test, 3)))));  
  63.             }  
  64.         } catch (Exception e) {  
  65.             e.printStackTrace();  
  66.         }  
  67.     }  
  68. }  
 

 

训练数据文件:

  1. 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 1  
  2. 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 1  
  3. 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 1  
  4. 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 0  
  5. 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 1  
  6. 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 0  
 

 

  1. 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5  
  2. 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8  
  3. 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2  
  4. 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5  
  5. 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5  
  6. 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5  
 

 

程序运行结果:

  1. 测试元组: 1.0 1.1 1.2 2.1 0.3 2.3 1.4 0.5 类别为: 1  
  2. 测试元组: 1.7 1.2 1.4 2.0 0.2 2.5 1.2 0.8 类别为: 1  
  3. 测试元组: 1.2 1.8 1.6 2.5 0.1 2.2 1.8 0.2 类别为: 1  
  4. 测试元组: 1.9 2.1 6.2 1.1 0.9 3.3 2.4 5.5 类别为: 0  
  5. 测试元组: 1.0 0.8 1.6 2.1 0.2 2.3 1.6 0.5 类别为: 1  
  6. 测试元组: 1.6 2.1 5.2 1.1 0.8 3.6 2.4 4.5 类别为: 0  
 

 

由结果可以看出,分类的测试结果是比较准确的!



转自:http://blog.csdn.net/luowen3405/article/details/6278764


参考:http://blog.csdn.net/xlm289348/article/details/8876353

http://coolshell.cn/articles/8052.html#more-8052

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值