数据挖掘-基于贝叶斯算法及KNN算法的newsgroup18828文本分类器的JAVA实现(下)

本文接 数据挖掘-基于贝叶斯算法及KNN算法的newsgroup18828文档分类器的JAVA实现(上) 

(update 2012.12.28 关于本项目下载及运行的常见问题 FAQ见 newsgroup18828文本分类器、文本聚类器、关联分析频繁模式挖掘算法的Java实现工程下载及运行FAQ )

上文中描述了newsgroup18828文档集的预处理及贝叶斯算法的JAVA实现,下面我们来看看如何实现基于KNN算法的newsgroup文本分类器

1 KNN算法的描述

KNN算法描述如下:
STEP ONE:文本向量化表示,由特征词的TF*IDF值计算
STEP TWO:在新文本到达后,根据特征词确定新文本的向量
STEP THREE:在训练文本集中选出与新文本最相似的 K 个文本,相似度用向量夹角余弦度量,计算公式为:


其中,K 值的确定目前没有很好的方法,一般采用先定一个初始值,然后根据实验测试的结果调整 K 值
本项目中K取20

STEP FOUR:在新文本的 K 个邻居中,依次计算每类的权重,每类的权重等于K个邻居中属于该类的训练样本与测试样本的相似度之和。
STEP FIVE:比较类的权重,将文本分到权重最大的那个类别中。

2 文档TF-IDF计算及向量化表示

实现KNN算法首先要实现文档的向量化表示
计算特征词的TF*IDF,每个文档的向量由包含所有特征词的TF*IDF值组成,每一维对应一个特征词

TF及IDF的计算公式如下,分别为特征词的特征项频率和逆文档频率


文档向量计算类 ComputeWordsVector.java如下

[java]  view plain copy 在CODE上查看代码片 派生到我的代码片
  1. package com.pku.yangliu;  
  2. import java.io.BufferedReader;  
  3. import java.io.File;  
  4. import java.io.FileReader;  
  5. import java.io.FileWriter;  
  6. import java.io.IOException;  
  7. import java.util.SortedMap;  
  8. import java.util.Map;  
  9. import java.util.Set;  
  10. import java.util.TreeMap;  
  11. import java.util.Iterator;  
  12.   
  13. /**计算文档的属性向量,将所有文档向量化 
  14.  * 
  15.  */  
  16. public class ComputeWordsVector {  
  17.       
  18.     /**计算文档的TF属性向量,直接写成二维数组遍历形式即可,没必要递归 
  19.      * @param strDir 处理好的newsgroup文件目录的绝对路径 
  20.      * @param trainSamplePercent 训练样例集占每个类目的比例 
  21.      * @param indexOfSample 测试样例集的起始的测试样例编号 
  22.      * @param wordMap 属性词典map 
  23.      * @throws IOException  
  24.      */  
  25.     public void computeTFMultiIDF(String strDir, double trainSamplePercent, int indexOfSample, Map<String, Double> iDFPerWordMap, Map<String, Double> wordMap) throws IOException{  
  26.         File fileDir = new File(strDir);  
  27.         String word;  
  28.         SortedMap<String,Double> TFPerDocMap = new TreeMap<String,Double>();  
  29.         //注意可以用两个写文件,一个专门写测试样例,一个专门写训练样例,用sampleType的值来表示  
  30.         String trainFileDir = "F:/DataMiningSample/docVector/wordTFIDFMapTrainSample"+indexOfSample;  
  31.         String testFileDir = "F:/DataMiningSample/docVector/wordTFIDFMapTestSample"+indexOfSample;  
  32.         FileWriter tsTrainWriter = new FileWriter(new File(trainFileDir));  
  33.         FileWriter tsTestWrtier = new FileWriter(new File(testFileDir));  
  34.         FileWriter tsWriter = tsTrainWriter;  
  35.         File[] sampleDir = fileDir.listFiles();  
  36.         for(int i = 0; i < sampleDir.length; i++){  
  37.             String cateShortName = sampleDir[i].getName();  
  38.             System.out.println("compute: " + cateShortName);  
  39.             File[] sample = sampleDir[i].listFiles();  
  40.             double testBeginIndex = indexOfSample*(sample.length * (1-trainSamplePercent));//测试样例的起始文件序号  
  41.             double testEndIndex = (indexOfSample+1)*(sample.length * (1-trainSamplePercent));//测试样例集的结束文件序号  
  42.             System.out.println("dirName_total length:"+sampleDir[i].getCanonicalPath()+"_"+sample.length);  
  43.             System.out.println(trainSamplePercent + " length:"+sample.length * trainSamplePercent +" testBeginIndex:"+testBeginIndex+" testEndIndex"+ testEndIndex);      
  44.             for(int j = 0;j < sample.length; j++){  
  45.                 TFPerDocMap.clear();  
  46.                 FileReader samReader = new FileReader(sample[j]);  
  47.                 BufferedReader samBR = new BufferedReader(samReader);  
  48.                 String fileShortName = sample[j].getName();  
  49.                 Double wordSumPerDoc = 0.0;//计算每篇文档的总词数  
  50.                 while((word = samBR.readLine()) != null){  
  51.                     if(!word.isEmpty() && wordMap.containsKey(word)){//必须是属性词典里面的词,去掉的词不考虑  
  52.                         wordSumPerDoc++;  
  53.                         if(TFPerDocMap.containsKey(word)){  
  54.                             Double count =  TFPerDocMap.get(word);  
  55.                             TFPerDocMap.put(word, count + 1);  
  56.                         }  
  57.                         else {  
  58.                             TFPerDocMap.put(word, 1.0);  
  59.                         }  
  60.                     }  
  61.                 }  
  62.                 //遍历一下当前文档的TFmap,除以文档的总词数换成词频,然后将词频乘以词的IDF,得到最终的特征权值,并且输出到文件  
  63.                 //注意测试样例和训练样例写入的文件不同  
  64.                 if(j >= testBeginIndex && j <= testEndIndex){  
  65.                     tsWriter = tsTestWrtier;  
  66.                 }  
  67.                 else{  
  68.                     tsWriter = tsTrainWriter;  
  69.                 }  
  70.                 Double wordWeight;  
  71.                 Set<Map.Entry<String, Double>> tempTF = TFPerDocMap.entrySet();  
  72.                 for(Iterator<Map.Entry<String, Double>> mt = tempTF.iterator(); mt.hasNext();){  
  73.                     Map.Entry<String, Double> me = mt.next();  
  74.                     //wordWeight =  (me.getValue() / wordSumPerDoc) * IDFPerWordMap.get(me.getKey());  
  75.                     //这里IDF暂时设为1,具体的计算IDF算法改进和实现见我的博客中关于kmeans聚类的博文  
  76.                     wordWeight =  (me.getValue() / wordSumPerDoc) * 1.0;  
  77.                     TFPerDocMap.put(me.getKey(), wordWeight);  
  78.                 }  
  79.                 tsWriter.append(cateShortName + " ");  
  80.                 String keyWord = fileShortName.substring(0,5);  
  81.                 tsWriter.append(keyWord+ " ");  
  82.                 Set<Map.Entry<String, Double>> tempTF2 = TFPerDocMap.entrySet();  
  83.                 for(Iterator<Map.Entry<String, Double>> mt = tempTF2.iterator(); mt.hasNext();){  
  84.                     Map.Entry<String, Double> ne = mt.next();  
  85.                     tsWriter.append(ne.getKey() + " " + ne.getValue() + " ");  
  86.                 }  
  87.                 tsWriter.append("\n");    
  88.                 tsWriter.flush();  
  89.             }  
  90.         }  
  91.         tsTrainWriter.close();  
  92.         tsTestWrtier.close();  
  93.         tsWriter.close();  
  94.     }  
  95.       
  96.     /**统计每个词的总的出现次数,返回出现次数大于3次的词汇构成最终的属性词典 
  97.      * @param strDir 处理好的newsgroup文件目录的绝对路径 
  98.      * @throws IOException  
  99.      */  
  100.     public SortedMap<String,Double> countWords(String strDir,Map<String, Double> wordMap) throws IOException{  
  101.         File sampleFile = new File(strDir);  
  102.         File [] sample = sampleFile.listFiles();  
  103.         String word;  
  104.         for(int i = 0; i < sample.length; i++){  
  105.             if(!sample[i].isDirectory()){  
  106.                 if(sample[i].getName().contains("stemed")){  
  107.                     FileReader samReader = new FileReader(sample[i]);  
  108.                     BufferedReader samBR = new BufferedReader(samReader);  
  109.                     while((word = samBR.readLine()) != null){  
  110.                         if(!word.isEmpty() && wordMap.containsKey(word)){  
  111.                             double count = wordMap.get(word) + 1;  
  112.                             wordMap.put(word, count);  
  113.                         }  
  114.                         else {  
  115.                             wordMap.put(word, 1.0);  
  116.                         }  
  117.                     }  
  118.                 }     
  119.             }  
  120.             else countWords(sample[i].getCanonicalPath(),wordMap);  
  121.         }  
  122.         //只返回出现次数大于3的单词  
  123.         SortedMap<String,Double> newWordMap = new TreeMap<String,Double>();  
  124.         Set<Map.Entry<String,Double>> allWords = wordMap.entrySet();  
  125.         for(Iterator<Map.Entry<String,Double>> it = allWords.iterator(); it.hasNext();){  
  126.             Map.Entry<String, Double> me = it.next();  
  127.             if(me.getValue() >= 1){  
  128.                 newWordMap.put(me.getKey(),me.getValue());  
  129.             }  
  130.         }  
  131.         return newWordMap;    
  132.     }  
  133.       
  134.     /**打印属性词典 
  135.      * @param SortedMap<String,Double> 属性词典 
  136.      * @throws IOException  
  137.      */  
  138.     void printWordMap(Map<String, Double> wordMap) throws IOException {  
  139.         // TODO Auto-generated method stub  
  140.         System.out.println("printWordMap");  
  141.         int countLine = 0;  
  142.         File outPutFile = new File("F:/DataMiningSample/docVector/allDicWordCountMap.txt");  
  143.         FileWriter outPutFileWriter = new FileWriter(outPutFile);  
  144.         Set<Map.Entry<String,Double>> allWords = wordMap.entrySet();  
  145.         for(Iterator<Map.Entry<String,Double>> it = allWords.iterator(); it.hasNext();){  
  146.             Map.Entry<String, Double> me = it.next();  
  147.             outPutFileWriter.write(me.getKey()+" "+me.getValue()+"\n");  
  148.             countLine++;  
  149.         }  
  150.         System.out.println("WordMap size" + countLine);  
  151.     }  
  152.       
  153.     /**计算IDF,即属性词典中每个词在多少个文档中出现过 
  154.      * @param SortedMap<String,Double> 属性词典 
  155.      * @return 单词的IDFmap 
  156.      * @throws IOException  
  157.      */  
  158.     SortedMap<String,Double> computeIDF(String string, Map<String, Double> wordMap) throws IOException {  
  159.         // TODO Auto-generated method stub  
  160.         File fileDir = new File(string);  
  161.         String word;  
  162.         SortedMap<String,Double> IDFPerWordMap = new TreeMap<String,Double>();    
  163.         Set<Map.Entry<String, Double>> wordMapSet = wordMap.entrySet();  
  164.         for(Iterator<Map.Entry<String, Double>> pt = wordMapSet.iterator(); pt.hasNext();){  
  165.             Map.Entry<String, Double> pe = pt.next();  
  166.             Double coutDoc = 0.0;  
  167.             String dicWord = pe.getKey();  
  168.             File[] sampleDir = fileDir.listFiles();  
  169.             for(int i = 0; i < sampleDir.length; i++){  
  170.                 File[] sample = sampleDir[i].listFiles();  
  171.                 for(int j = 0;j < sample.length; j++){  
  172.                     FileReader samReader = new FileReader(sample[j]);  
  173.                     BufferedReader samBR = new BufferedReader(samReader);  
  174.                     boolean isExited = false;  
  175.                     while((word = samBR.readLine()) != null){  
  176.                         if(!word.isEmpty() && word.equals(dicWord)){  
  177.                             isExited = true;  
  178.                             break;  
  179.                         }  
  180.                     }  
  181.                     if(isExited) coutDoc++;   
  182.                     }     
  183.                 }  
  184.             //计算单词的IDF  
  185.             Double IDF = Math.log(20000 / coutDoc) / Math.log(10);  
  186.             IDFPerWordMap.put(dicWord, IDF);  
  187.             }  
  188.         return IDFPerWordMap;  
  189.     }  
  190. }  

3 KNN算法的实现

KNN算法的实现要注意

(1)用TreeMap<String,TreeMap<String,Double>>保存测试集和训练集
(2)注意要以"类目_文件名"作为每个文件的key,才能避免同名不同内容的文件出现
(3)注意设置JM参数,否则会出现JAVA heap溢出错误
(4)本程序用向量夹角余弦计算相似度

KNN算法实现类 KNNClassifier.java如下

[java]  view plain copy 在CODE上查看代码片 派生到我的代码片
  1. package com.pku.yangliu;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.io.FileWriter;  
  7. import java.io.IOException;  
  8. import java.util.Comparator;  
  9. import java.util.HashMap;  
  10. import java.util.Iterator;  
  11. import java.util.Map;  
  12. import java.util.Set;  
  13. import java.util.TreeMap;  
  14.   
  15. /**KNN算法的实现类,本程序用向量夹角余弦计算相似度 
  16.  * 
  17.  */  
  18.   
  19. public class KNNClassifier {  
  20.       
  21.     /**用KNN算法对测试文档集分类,读取测试样例和训练样例集 
  22.      * @param trainFiles 训练样例的所有向量构成的文件 
  23.      * @param testFiles 测试样例的所有向量构成的文件 
  24.      * @param kNNResultFile KNN分类结果文件路径 
  25.      * @return double 分类准确率 
  26.      * @throws IOException  
  27.      */  
  28.     private double doProcess(String trainFiles, String testFiles,  
  29.             String kNNResultFile) throws IOException {  
  30.         // TODO Auto-generated method stub  
  31.         //首先读取训练样本和测试样本,用map<String,map<word,TF>>保存测试集和训练集,注意训练样本的类目信息也得保存,  
  32.         //然后遍历测试样本,对于每一个测试样本去计算它与所有训练样本的相似度,相似度保存入map<String,double>有  
  33.         //序map中去,然后取前K个样本,针对这k个样本来给它们所属的类目计算权重得分,对属于同一个类目的权重求和进而得到  
  34.         //最大得分的类目,就可以判断测试样例属于该类目下,K值可以反复测试,找到分类准确率最高的那个值  
  35.         //!注意要以"类目_文件名"作为每个文件的key,才能避免同名不同内容的文件出现  
  36.         //!注意设置JM参数,否则会出现JAVA heap溢出错误  
  37.         //!本程序用向量夹角余弦计算相似度  
  38.         File trainSamples = new File(trainFiles);  
  39.         BufferedReader trainSamplesBR = new BufferedReader(new FileReader(trainSamples));  
  40.         String line;  
  41.         String [] lineSplitBlock;  
  42.         Map<String,TreeMap<String,Double>> trainFileNameWordTFMap = new TreeMap<String,TreeMap<String,Double>> ();  
  43.         TreeMap<String,Double> trainWordTFMap = new TreeMap<String,Double>();  
  44.         while((line = trainSamplesBR.readLine()) != null){  
  45.             lineSplitBlock = line.split(" ");  
  46.             trainWordTFMap.clear();  
  47.             for(int i = 2; i < lineSplitBlock.length; i = i + 2){  
  48.                 trainWordTFMap.put(lineSplitBlock[i], Double.valueOf(lineSplitBlock[i+1]));  
  49.             }  
  50.             TreeMap<String,Double> tempMap = new TreeMap<String,Double>();  
  51.             tempMap.putAll(trainWordTFMap);  
  52.             trainFileNameWordTFMap.put(lineSplitBlock[0]+"_"+lineSplitBlock[1], tempMap);  
  53.         }  
  54.         trainSamplesBR.close();  
  55.           
  56.         File testSamples = new File(testFiles);  
  57.         BufferedReader testSamplesBR = new BufferedReader(new FileReader(testSamples));  
  58.         Map<String,Map<String,Double>> testFileNameWordTFMap = new TreeMap<String,Map<String,Double>> ();  
  59.         Map<String,String> testClassifyCateMap = new TreeMap<String, String>();//分类形成的<文件名,类目>对  
  60.         Map<String,Double> testWordTFMap = new TreeMap<String,Double>();  
  61.         while((line = testSamplesBR.readLine()) != null){  
  62.             lineSplitBlock = line.split(" ");  
  63.             testWordTFMap.clear();  
  64.             for(int i = 2; i < lineSplitBlock.length; i = i + 2){  
  65.                 testWordTFMap.put(lineSplitBlock[i], Double.valueOf(lineSplitBlock[i+1]));  
  66.             }  
  67.             TreeMap<String,Double> tempMap = new TreeMap<String,Double>();  
  68.             tempMap.putAll(testWordTFMap);  
  69.             testFileNameWordTFMap.put(lineSplitBlock[0]+"_"+lineSplitBlock[1], tempMap);  
  70.         }  
  71.         testSamplesBR.close();  
  72.         //下面遍历每一个测试样例计算与所有训练样本的距离,做分类  
  73.         String classifyResult;  
  74.         FileWriter testYangliuWriter = new FileWriter(new File("F:/DataMiningSample/docVector/yangliuTest"));  
  75.         FileWriter KNNClassifyResWriter = new FileWriter(kNNResultFile);  
  76.         Set<Map.Entry<String,Map<String,Double>>> testFileNameWordTFMapSet = testFileNameWordTFMap.entrySet();  
  77.         for(Iterator<Map.Entry<String,Map<String,Double>>> it = testFileNameWordTFMapSet.iterator(); it.hasNext();){  
  78.             Map.Entry<String, Map<String,Double>> me = it.next();  
  79.             classifyResult = KNNComputeCate(me.getKey(), me.getValue(), trainFileNameWordTFMap, testYangliuWriter);  
  80.             KNNClassifyResWriter.append(me.getKey()+" "+classifyResult+"\n");  
  81.             KNNClassifyResWriter.flush();  
  82.             testClassifyCateMap.put(me.getKey(), classifyResult);  
  83.         }  
  84.         KNNClassifyResWriter.close();  
  85.         //计算分类的准确率  
  86.         double righteCount = 0;  
  87.         Set<Map.Entry<String, String>> testClassifyCateMapSet = testClassifyCateMap.entrySet();  
  88.         for(Iterator <Map.Entry<String, String>> it = testClassifyCateMapSet.iterator(); it.hasNext();){  
  89.             Map.Entry<String, String> me = it.next();  
  90.             String rightCate = me.getKey().split("_")[0];  
  91.             if(me.getValue().equals(rightCate)){  
  92.                 righteCount++;  
  93.             }  
  94.         }     
  95.         testYangliuWriter.close();  
  96.         return righteCount / testClassifyCateMap.size();  
  97.     }  
  98.       
  99.     /**对于每一个测试样本去计算它与所有训练样本的向量夹角余弦相似度 
  100.      * 相似度保存入map<String,double>有序map中去,然后取前K个样本, 
  101.      * 针对这k个样本来给它们所属的类目计算权重得分,对属于同一个类 
  102.      * 目的权重求和进而得到最大得分的类目,就可以判断测试样例属于该 
  103.      * 类目下。K值可以反复测试,找到分类准确率最高的那个值 
  104.      * @param testWordTFMap 当前测试文件的<单词,词频>向量 
  105.      * @param trainFileNameWordTFMap 训练样本<类目_文件名,向量>Map 
  106.      * @param testYangliuWriter  
  107.      * @return String K个邻居权重得分最大的类目 
  108.      * @throws IOException  
  109.      */  
  110.     private String KNNComputeCate(  
  111.             String testFileName,  
  112.             Map<String, Double> testWordTFMap,  
  113.             Map<String, TreeMap<String, Double>> trainFileNameWordTFMap, FileWriter testYangliuWriter) throws IOException {  
  114.         // TODO Auto-generated method stub  
  115.         HashMap<String,Double> simMap = new HashMap<String,Double>();//<类目_文件名,距离> 后面需要将该HashMap按照value排序  
  116.         double similarity;  
  117.         Set<Map.Entry<String,TreeMap<String,Double>>> trainFileNameWordTFMapSet = trainFileNameWordTFMap.entrySet();  
  118.         for(Iterator<Map.Entry<String,TreeMap<String,Double>>> it = trainFileNameWordTFMapSet.iterator(); it.hasNext();){  
  119.             Map.Entry<String, TreeMap<String,Double>> me = it.next();  
  120.             similarity = computeSim(testWordTFMap, me.getValue());  
  121.             simMap.put(me.getKey(),similarity);  
  122.         }  
  123.         //下面对simMap按照value排序  
  124.         ByValueComparator bvc = new ByValueComparator(simMap);  
  125.         TreeMap<String,Double> sortedSimMap = new TreeMap<String,Double>(bvc);  
  126.         sortedSimMap.putAll(simMap);  
  127.           
  128.         //在disMap中取前K个最近的训练样本对其类别计算距离之和,K的值通过反复试验而得  
  129.         Map<String,Double> cateSimMap = new TreeMap<String,Double>();//K个最近训练样本所属类目的距离之和  
  130.         double K = 20;  
  131.         double count = 0;  
  132.         double tempSim;  
  133.           
  134.         Set<Map.Entry<String, Double>> simMapSet = sortedSimMap.entrySet();  
  135.         for(Iterator<Map.Entry<String, Double>> it = simMapSet.iterator(); it.hasNext();){  
  136.             Map.Entry<String, Double> me = it.next();  
  137.             count++;  
  138.             String categoryName = me.getKey().split("_")[0];  
  139.             if(cateSimMap.containsKey(categoryName)){  
  140.                 tempSim = cateSimMap.get(categoryName);  
  141.                 cateSimMap.put(categoryName, tempSim + me.getValue());  
  142.             }  
  143.             else cateSimMap.put(categoryName, me.getValue());  
  144.             if (count > K) break;  
  145.         }  
  146.         //下面到cateSimMap里面把sim最大的那个类目名称找出来  
  147.         //testYangliuWriter.flush();  
  148.         //testYangliuWriter.close();  
  149.         double maxSim = 0;  
  150.         String bestCate = null;  
  151.         Set<Map.Entry<String, Double>> cateSimMapSet = cateSimMap.entrySet();  
  152.         for(Iterator<Map.Entry<String, Double>> it = cateSimMapSet.iterator(); it.hasNext();){  
  153.             Map.Entry<String, Double> me = it.next();  
  154.             if(me.getValue()> maxSim){  
  155.                 bestCate = me.getKey();  
  156.                 maxSim = me.getValue();  
  157.             }  
  158.         }  
  159.         return bestCate;  
  160.     }  
  161.   
  162.     /**计算测试样本向量和训练样本向量的相似度 
  163.      * @param testWordTFMap 当前测试文件的<单词,词频>向量 
  164.      * @param trainWordTFMap 当前训练样本<单词,词频>向量 
  165.      * @return Double 向量之间的相似度 以向量夹角余弦计算 
  166.      * @throws IOException  
  167.      */  
  168.     private double computeSim(Map<String, Double> testWordTFMap,  
  169.             Map<String, Double> trainWordTFMap) {  
  170.         // TODO Auto-generated method stub  
  171.         double mul = 0, testAbs = 0, trainAbs = 0;  
  172.         Set<Map.Entry<String, Double>> testWordTFMapSet = testWordTFMap.entrySet();  
  173.         for(Iterator<Map.Entry<String, Double>> it = testWordTFMapSet.iterator(); it.hasNext();){  
  174.             Map.Entry<String, Double> me = it.next();  
  175.             if(trainWordTFMap.containsKey(me.getKey())){  
  176.                 mul += me.getValue()*trainWordTFMap.get(me.getKey());  
  177.             }  
  178.             testAbs += me.getValue() * me.getValue();  
  179.         }  
  180.         testAbs = Math.sqrt(testAbs);  
  181.           
  182.         Set<Map.Entry<String, Double>> trainWordTFMapSet = trainWordTFMap.entrySet();  
  183.         for(Iterator<Map.Entry<String, Double>> it = trainWordTFMapSet.iterator(); it.hasNext();){  
  184.             Map.Entry<String, Double> me = it.next();  
  185.             trainAbs += me.getValue()*me.getValue();  
  186.         }  
  187.         trainAbs = Math.sqrt(trainAbs);  
  188.         return mul / (testAbs * trainAbs);  
  189.     }  
  190.   
  191.     /**根据KNN算法分类结果文件生成正确类目文件,而正确率和混淆矩阵的计算可以复用贝叶斯算法类中的方法 
  192.      * @param kNNRightFile 分类正确类目文件 
  193.      * @param kNNResultFile 分类结果文件 
  194.      * @throws IOException  
  195.      */  
  196.     private void createRightFile(String kNNResultFile, String kNNRightFile) throws IOException {  
  197.         // TODO Auto-generated method stub  
  198.         String rightCate;  
  199.         FileReader fileR = new FileReader(kNNResultFile);  
  200.         FileWriter KNNRrightResult = new FileWriter(new File(kNNRightFile));  
  201.         BufferedReader fileBR = new BufferedReader(fileR);  
  202.         String line;  
  203.         String lineBlock[];  
  204.         while((line = fileBR.readLine()) != null){  
  205.             lineBlock = line.split(" ");  
  206.             rightCate = lineBlock[0].split("_")[0];  
  207.             KNNRrightResult.append(lineBlock[0]+" "+rightCate+"\n");  
  208.         }  
  209.         KNNRrightResult.flush();  
  210.         KNNRrightResult.close();  
  211.     }  
  212.           
  213.       
  214.     /** 
  215.      * @param args 
  216.      * @throws IOException  
  217.      */  
  218.     public void KNNClassifierMain(String[] args) throws IOException {  
  219.         // TODO Auto-generated method stub  
  220.         //wordMap是所有属性词的词典<单词,在所有文档中出现的次数>  
  221.         double[] accuracyOfEveryExp = new double[10];  
  222.         double accuracyAvg,sum = 0;  
  223.         KNNClassifier knnClassifier = new KNNClassifier();  
  224.         NaiveBayesianClassifier nbClassifier = new NaiveBayesianClassifier();  
  225.         Map<String,Double> wordMap = new TreeMap<String,Double>();  
  226.         Map<String,Double> IDFPerWordMap = new TreeMap<String,Double>();      
  227.         ComputeWordsVector computeWV = new ComputeWordsVector();  
  228.         wordMap = computeWV.countWords("F:/DataMiningSample/processedSample_includeNotSpecial", wordMap);  
  229. IDFPerWordMap = computeWV.computeIDF("F:/DataMiningSample/processedSampleOnlySpecial",wordMap);  
  230.         computeWV.printWordMap(wordMap);  
  231.         //首先生成KNN算法10次试验需要的文档TF矩阵文件  
  232.         for(int i = 0; i < 10; i++){  
  233.             computeWV.computeTFMultiIDF("F:/DataMiningSample/processedSampleOnlySpecial",0.9, i, IDFPerWordMap,wordMap);  
  234.             String trainFiles = "F:/DataMiningSample/docVector/wordTFIDFMapTrainSample"+i;  
  235.             String testFiles = "F:/DataMiningSample/docVector/wordTFIDFMapTestSample"+i;  
  236.             String kNNResultFile = "F:/DataMiningSample/docVector/KNNClassifyResult"+i;  
  237.             String kNNRightFile = "F:/DataMiningSample/docVector/KNNClassifyRight"+i;  
  238.             accuracyOfEveryExp[i] = knnClassifier.doProcess(trainFiles, testFiles, kNNResultFile);  
  239.             knnClassifier.createRightFile(kNNResultFile,kNNRightFile);  
  240.             accuracyOfEveryExp[i] = nbClassifier.computeAccuracy(kNNResultFile, kNNRightFile);//计算准确率复用贝叶斯算法中的方法  
  241.             sum += accuracyOfEveryExp[i];  
  242.             System.out.println("The accuracy for KNN Classifier in "+i+"th Exp is :" + accuracyOfEveryExp[i]);  
  243.         }  
  244.         accuracyAvg = sum / 10;  
  245.         System.out.println("The average accuracy for KNN Classifier in all Exps is :" + accuracyAvg);  
  246.     }  
  247.       
  248.     //对HashMap按照value做排序  
  249.     static class ByValueComparator implements Comparator<Object> {  
  250.         HashMap<String, Double> base_map;  
  251.   
  252.         public ByValueComparator(HashMap<String, Double> disMap) {  
  253.             this.base_map = disMap;  
  254.         }  
  255.           
  256.         @Override  
  257.         public int compare(Object o1, Object o2) {  
  258.             // TODO Auto-generated method stub  
  259.             String arg0 = o1.toString();  
  260.             String arg1 = o2.toString();  
  261.             if (!base_map.containsKey(arg0) || !base_map.containsKey(arg1)) {  
  262.                 return 0;  
  263.             }  
  264.             if (base_map.get(arg0) < base_map.get(arg1)) {  
  265.                 return 1;  
  266.             } else if (base_map.get(arg0) == base_map.get(arg1)) {  
  267.                 return 0;  
  268.             } else {  
  269.                 return -1;  
  270.             }  
  271.         }  
  272.     }  
  273. }  

分类器主类

[java]  view plain copy 在CODE上查看代码片 派生到我的代码片
  1. package com.pku.yangliu;  
  2.   
  3. /**分类器主分类,依次执行数据预处理、朴素贝叶斯分类、KNN分类 
  4.  * 
  5.  */  
  6. public class ClassifierMain {  
  7.   
  8.     public static void main(String[] args) throws Exception {  
  9.         // TODO Auto-generated method stub  
  10.         DataPreProcess DataPP = new DataPreProcess();  
  11.         NaiveBayesianClassifier nbClassifier = new NaiveBayesianClassifier();  
  12.         KNNClassifier knnClassifier = new KNNClassifier();  
  13.         //DataPP.BPPMain(args);  
  14.         nbClassifier.NaiveBayesianClassifierMain(args);  
  15.         knnClassifier.KNNClassifierMain(args);  
  16.     }  
  17. }  


5 KNN算法的分类结果

用混淆矩阵表示如下,第6次实验准确率达到82.10%

程序运行环境硬件环境:Intel Core 2 Duo CPU T5750 2GHZ, 2G内存,相同硬件环境计算和贝叶斯算法做对比

实验结果如上所示 取出现次数大于等于4次的词共计30095个作为特征词: 10次交叉验证实验平均准确率78.19%,用时1h55min,10词实验准确率范围73.62%-82.10%,其中有3次实验准确率超过80%

6 朴素贝叶斯与KNN分类准确率对比
取出现次数大于等于4次的词共计30095个作为特征词,做10次交叉验证实验,朴素贝叶斯和KNN算法对Newsgroup文档分类结果对比:

点击打开链接

结论
分类准确率上,KNN算法更优
分类速度上,朴素贝叶斯算法更优

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值