KNN分类器及实现



出处:http://blog.csdn.net/zhongkejingwang/article/details/44132771

   KNN即K-Nearest Neighbor,是数据挖掘中一种最简单的分类方法,即要判断某一个样本属于已知样本种类中的哪一类时,通过计算找出所有样本中与测试样本最近或者最相似的K个样本,统计这K个样本中哪一种类最多则把测试样本归位该类。如何衡量两个样本的相似度?可以用向量的p-范数来定义。

假设有两个样本X=(x1, x2, ..., xn),Y=(y1, y2, ..., yn),则他们之间的相似度可以用以下向量p-范数定义:

                

当p=2时即为计算X、Y的欧几里得距离。

  本文将介绍用Java实现KNN分类器对Iris数据进行分类。Iris数据如下:


  前面四个item是属性,最后一个是类别名,总共有三类。完整的数据集可点击这里下载

拿到原始数据后为了测试KNN分类效果,需要在原始数据中随机抽取一部分作为测试集,另一部分作为训练集。随机抽取的方法可以用下面代码实现:

[java]   view plain copy
  1. /** 
  2.      * 将数据集划分为训练集和测试集,随机划分 
  3.      *  
  4.      * @param filePath 
  5.      *            数据集文件路径 
  6.      * @param testCount 
  7.      *            测试集个数 
  8.      * @param outputPath 
  9.      *            输出路径 
  10.      * @throws Exception 
  11.      */  
  12.     public static void splitDataSet(String filePath, int testCount,  
  13.             String outputPath) throws Exception  
  14.     {  
  15.         BufferedWriter trainFile = new BufferedWriter(new FileWriter(new File(  
  16.                 outputPath + "/train.txt")));  
  17.         BufferedWriter testFile = new BufferedWriter(new FileWriter(new File(  
  18.                 outputPath + "/test.txt")));  
  19.         BufferedReader input = new BufferedReader(new FileReader(new File(  
  20.                 filePath)));  
  21.         List<String> lines = new ArrayList<String>();  
  22.         String line = null;  
  23.         //将所有数据读取到一个List里  
  24.         while ((line = input.readLine()) != null)  
  25.             lines.add(line);  
  26.         //遍历一次List,每次产生一个随机序号,将该随机序号和当前序号内容进行交换  
  27.         for (int i = 0; i < lines.size(); i++)  
  28.         {  
  29.             int ran = (int) (Math.random() * lines.size());  
  30.             String temp = lines.get(i);  
  31.             lines.set(i, lines.get(ran));  
  32.             lines.set(ran, temp);  
  33.         }  
  34.         int i = 0;  
  35.         //将指定数目的测试集写进test.txt中  
  36.         for (; i < testCount; i++)  
  37.         {  
  38.             testFile.write(lines.get(i) + "\n");  
  39.             testFile.flush();  
  40.         }  
  41.         //剩余的写进train.txt中  
  42.         for (; i < lines.size(); i++)  
  43.         {  
  44.             trainFile.write(lines.get(i) + "\n");  
  45.             trainFile.flush();  
  46.         }  
  47.         testFile.close();  
  48.         trainFile.close();  
  49.     }  

   调用这个方法后就可以得到train.txt和test.txt两份数据了。

  接下来将数据读入:

[java]   view plain copy
  1. /** 
  2.      * 根据文件生成训练集,注意:程序将以第一个出现的非数字的属性作为类别名称 
  3.      *  
  4.      * @param fileName 
  5.      *            文件名 
  6.      * @param sep 
  7.      *            分隔符 
  8.      * @return 
  9.      * @throws Exception 
  10.      */  
  11.     public List<DataNode> getDataList(String fileName, String sep)  
  12.             throws Exception  
  13.     {  
  14.         List<DataNode> list = new ArrayList<DataNode>();  
  15.         BufferedReader br = new BufferedReader(new FileReader(  
  16.                 new File(fileName)));  
  17.         String line = null;  
  18.         while ((line = br.readLine()) != null)  
  19.         {  
  20.             String splits[] = line.split(sep);  
  21.             //DataNode类用于保存数据属性和数据类别  
  22.             DataNode node = new DataNode();  
  23.             int i = 0;  
  24.             for (; i < splits.length; i++)  
  25.             {  
  26.                 try  
  27.                 {  
  28.                     node.addAttrib(Float.valueOf(splits[i]));  
  29.                 } catch (NumberFormatException e)  
  30.                 {  
  31.                     // 非数字,则为类别名称,将类别映射为数字  
  32.                     if (!mTypes.containsKey(splits[i]))  
  33.                     {  
  34.                         mTypes.put(splits[i], mTypeCount);  
  35.                         mTypeCount++;  
  36.                     }  
  37.                     node.setType(mTypes.get(splits[i]));  
  38.                     list.add(node);  
  39.                 }  
  40.             }  
  41.         }  
  42.         return list;  
  43.     }  
对于testList中的每一个样本,均与所有trainList中的样本进行计算,取出最接近的K个样本并返回:

KnnClassifier.java

[java]   view plain copy
  1. package com.jingchen.knn;  
  2.   
  3. import java.util.List;  
  4.   
  5. /** 
  6.  * @author chenjing 
  7.  *  
  8.  */  
  9. public class KnnClassifier  
  10. {  
  11.     //k个近邻节点  
  12.     private int k;  
  13.     private KNode[] mNearestK;  
  14.     private List<DataNode> mTrainData;  
  15.   
  16.     public KnnClassifier(int k, List<DataNode> trainList)  
  17.     {  
  18.         mTrainData = trainList;  
  19.         this.k = k;  
  20.         mNearestK = new KNode[k];  
  21.         for (int i = 0; i < k; i++)  
  22.             mNearestK[i] = new KNode();  
  23.     }  
  24.     public void setK(int k){  
  25.         this.k = k;  
  26.         mNearestK = new KNode[k];  
  27.         for (int i = 0; i < k; i++)  
  28.             mNearestK[i] = new KNode();  
  29.     }  
  30.     private void train(DataNode test, float p)  
  31.     {  
  32.         for (int i = 0; i < mTrainData.size(); i++)  
  33.         {  
  34.             putNode(getSim(test, mTrainData.get(i), p));  
  35.         }  
  36.     }  
  37.   
  38.     /** 
  39.      * 将新计算出来的节点与k个近邻节点比较,如果比其中之一小则插入 
  40.      * @param node 
  41.      */  
  42.     private void putNode(KNode node)  
  43.     {  
  44.         for (int i = 0; i < k; i++)  
  45.         {  
  46.             if (node.getD() < mNearestK[i].getD())  
  47.             {  
  48.                 for (int j = k - 1; j > i; j--)  
  49.                     mNearestK[j] = mNearestK[j - 1];  
  50.                 mNearestK[i] = node;  
  51.                 break;  
  52.             }  
  53.         }  
  54.     }  
  55.   
  56.     /** 
  57.      * 获取相似度并封装成一个KNode类型返回 
  58.      * @param test 
  59.      * @param trainNode 
  60.      * @param p 
  61.      * @return 
  62.      */  
  63.     private KNode getSim(DataNode test, DataNode trainNode, float p)  
  64.     {  
  65.         List<Float> list1 = test.getAttribs();  
  66.         List<Float> list2 = trainNode.getAttribs();  
  67.         float d = 0;  
  68.         for (int i = 0; i < list1.size(); i++)  
  69.             d += Math.pow(  
  70.                     Math.abs(list1.get(i).floatValue() - list2.get(i).floatValue()), p);  
  71.         d = (float) Math.pow(d, 1/p);  
  72.         KNode node = new KNode(d, trainNode.getType());  
  73.         return node;  
  74.     }  
  75.   
  76.     private void reset()  
  77.     {  
  78.         for (int i = 0; i < k; i++)  
  79.             mNearestK[i].reset();  
  80.     }  
  81.   
  82.     /** 
  83.      * 返回K个近邻节点 
  84.      * @param test 
  85.      * @param p 
  86.      * @return 
  87.      */  
  88.     public KNode[] getKNN(DataNode test, float p)  
  89.     {  
  90.         reset();  
  91.         train(test, p);  
  92.         return mNearestK;  
  93.     }  
  94. }  

  main方法:

[java]   view plain copy
  1. public static void main(String[] args) throws Exception  
  2.     {  
  3.         DataUtil util = DataUtil.getInstance();  
  4.         //获得训练集和测试集  
  5.         List<DataNode> trainList = util.getDataList("E:/train.txt"",");  
  6.         List<DataNode> testList = util.getDataList("E:/test.txt"",");  
  7.         int K = BASE_K;  
  8.         KnnClassifier classifier = new KnnClassifier(K, trainList);  
  9.         BufferedWriter output = new BufferedWriter(new FileWriter(new File(  
  10.                 "E:/output.txt")));  
  11.         int typeCount = util.getTypeCount();  
  12.         int[] count = new int[typeCount];  
  13.         for (int i = 0; i < testList.size();)  
  14.         {  
  15.             for (int m = 0; m < typeCount; m++)  
  16.                 count[m] = 0;  
  17.             DataNode test = testList.get(i);  
  18.             classifier.setK(K);  
  19.             KNode[] nodes = classifier.getKNN(test, 2);  
  20.             for (int j = 0; j < nodes.length; j++)  
  21.                 count[nodes[j].getType()]++;  
  22.             int type = -1;  
  23.             int max = -1;  
  24.             for (int j = 0; j < typeCount; j++)  
  25.             {  
  26.                 if (count[j] > max)  
  27.                 {  
  28.                     max = count[j];  
  29.                     type = j;  
  30.                 } else if (count[j] == max)  
  31.                 {  
  32.                     // 存在两个类型分个数相同,无法判断属于哪个类型,增加K的值继续从该节点开始  
  33.                     type = -1;  
  34.                     K++;  
  35.                     break;  
  36.                 }  
  37.             }  
  38.             if (type == -1)  
  39.                 continue;  
  40.             else  
  41.             {  
  42.                 i++;  
  43.                 K = BASE_K;  
  44.             }  
  45.             //将分类结果写入文件  
  46.             List<Float> attribs = test.getAttribs();  
  47.             for (int n = 0; n < attribs.size(); n++)  
  48.             {  
  49.                 output.write(attribs.get(n) + ",");  
  50.                 output.flush();  
  51.             }  
  52.             output.write(util.getTypeName(type) + "\n");  
  53.             output.flush();  
  54.         }  
  55.         output.close();  
  56.   
  57.     }  

   经测试,KNN对Iris数据集分类准确率基本都在90+%以上,此分类方法也比较直观。数据集及完整的项目代码可以从这里下载: 点击下载


  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值