数据挖掘-基于贝叶斯算法及KNN算法的newsgroup18828文本分类器的JAVA实现(上)

(update 2012.12.28 关于本项目下载及运行的常见问题 FAQ见 newsgroup18828文本分类器、文本聚类器、关联分析频繁模式挖掘算法的Java实现工程下载及运行FAQ )

本文主要内容如下:
对newsgroup文档集进行预处理,提取出30095 个特征词

计算每篇文档中的特征词的TF*IDF值,实现文档向量化,在KNN算法中使用

Java实现了KNN算法及朴素贝叶斯算法的newsgroup文本分类器

1、Newsgroup文档集介绍

Newsgroups最早由Lang于1995收集并在[Lang 1995]中使用。它含有20000篇左右的Usenet文档,几乎平均分配20个不同的新闻组。除了其中4.5%的文档属于两个或两个以上的新闻组以外,其余文档仅属于一个新闻组,因此它通常被作为单标注分类问题来处理。Newsgroups已经成为文本分及聚类中常用的文档集。美国MIT大学Jason Rennie对Newsgroups作了必要的处理,使得每个文档只属于一个新闻组,形成Newsgroups-18828。

2、Newsgroup文档预处理

要做文本分类首先得完成文本的预处理,预处理的主要步骤如下

STEP ONE:英文词法分析,去除数字、连字符、标点符号、特殊 字符,所有大写字母转换成小写,可以用正则表达式
                     String res[] = line.split("[^a-zA-Z]");
STEP TWO:去停用词,过滤对分类无价值的词
STEP THRE: 词根还原stemming,基于Porter算法
文档预处理类 DataPreProcess.java如下
[java]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. package com.pku.yangliu;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.io.FileWriter;  
  7. import java.io.IOException;  
  8. import java.util.ArrayList;  
  9.   
  10. /**  
  11.  * Newsgroups文档集预处理类 
  12.  */  
  13. public class DataPreProcess {  
  14.       
  15.     /**输入文件调用处理数据函数 
  16.      * @param strDir newsgroup文件目录的绝对路径 
  17.      * @throws IOException  
  18.      */  
  19.     public void doProcess(String strDir) throws IOException{  
  20.         File fileDir = new File(strDir);  
  21.         if(!fileDir.exists()){  
  22.             System.out.println("File not exist:" + strDir);  
  23.             return;  
  24.         }  
  25.         String subStrDir = strDir.substring(strDir.lastIndexOf('/'));  
  26.         String dirTarget = strDir + "/../../processedSample_includeNotSpecial"+subStrDir;  
  27.         File fileTarget = new File(dirTarget);  
  28.         if(!fileTarget.exists()){//注意processedSample需要先建立目录建出来,否则会报错,因为母目录不存在  
  29.             fileTarget.mkdir();  
  30.         }  
  31.         File[] srcFiles = fileDir.listFiles();  
  32.         String[] stemFileNames = new String[srcFiles.length];  
  33.         for(int i = 0; i < srcFiles.length; i++){  
  34.             String fileFullName = srcFiles[i].getCanonicalPath();  
  35.             String fileShortName = srcFiles[i].getName();  
  36.             if(!new File(fileFullName).isDirectory()){//确认子文件名不是目录如果是可以再次递归调用  
  37.                 System.out.println("Begin preprocess:"+fileFullName);  
  38.                 StringBuilder stringBuilder = new StringBuilder();  
  39.                 stringBuilder.append(dirTarget + "/" + fileShortName);  
  40.                 createProcessFile(fileFullName, stringBuilder.toString());  
  41.                 stemFileNames[i] = stringBuilder.toString();  
  42.             }  
  43.             else {  
  44.                 fileFullName = fileFullName.replace("\\","/");  
  45.                 doProcess(fileFullName);  
  46.             }  
  47.         }  
  48.         //下面调用stem算法  
  49.         if(stemFileNames.length > 0 && stemFileNames[0] != null){  
  50.             Stemmer.porterMain(stemFileNames);  
  51.         }  
  52.     }  
  53.       
  54.     /**进行文本预处理生成目标文件 
  55.      * @param srcDir 源文件文件目录的绝对路径 
  56.      * @param targetDir 生成的目标文件的绝对路径 
  57.      * @throws IOException  
  58.      */  
  59.     private static void createProcessFile(String srcDir, String targetDir) throws IOException {  
  60.         // TODO Auto-generated method stub  
  61.         FileReader srcFileReader = new FileReader(srcDir);  
  62.         FileReader stopWordsReader = new FileReader("F:/DataMiningSample/stopwords.txt");  
  63.         FileWriter targetFileWriter = new FileWriter(targetDir);      
  64.         BufferedReader srcFileBR = new BufferedReader(srcFileReader);//装饰模式  
  65.         BufferedReader stopWordsBR = new BufferedReader(stopWordsReader);  
  66.         String line, resLine, stopWordsLine;  
  67.         //用stopWordsBR够着停用词的ArrayList容器  
  68.         ArrayList<String> stopWordsArray = new ArrayList<String>();  
  69.         while((stopWordsLine = stopWordsBR.readLine()) != null){  
  70.             if(!stopWordsLine.isEmpty()){  
  71.                 stopWordsArray.add(stopWordsLine);  
  72.             }  
  73.         }  
  74.         while((line = srcFileBR.readLine()) != null){  
  75.             resLine = lineProcess(line,stopWordsArray);  
  76.             if(!resLine.isEmpty()){  
  77.                 //按行写,一行写一个单词  
  78.                 String[] tempStr = resLine.split(" ");//\s  
  79.                 for(int i = 0; i < tempStr.length; i++){  
  80.                     if(!tempStr[i].isEmpty()){  
  81.                         targetFileWriter.append(tempStr[i]+"\n");  
  82.                     }  
  83.                 }  
  84.             }  
  85.         }  
  86.         targetFileWriter.flush();  
  87.         targetFileWriter.close();  
  88.         srcFileReader.close();  
  89.         stopWordsReader.close();  
  90.         srcFileBR.close();  
  91.         stopWordsBR.close();      
  92.     }  
  93.       
  94.     /**对每行字符串进行处理,主要是词法分析、去停用词和stemming 
  95.      * @param line 待处理的一行字符串 
  96.      * @param ArrayList<String> 停用词数组 
  97.      * @return String 处理好的一行字符串,是由处理好的单词重新生成,以空格为分隔符 
  98.      * @throws IOException  
  99.      */  
  100.     private static String lineProcess(String line, ArrayList<String> stopWordsArray) throws IOException {  
  101.         // TODO Auto-generated method stub  
  102.         //step1 英文词法分析,去除数字、连字符、标点符号、特殊字符,所有大写字母转换成小写,可以考虑用正则表达式  
  103.         String res[] = line.split("[^a-zA-Z]");  
  104.         //这里要小心,防止把有单词中间有数字和连字符的单词 截断了,但是截断也没事  
  105.           
  106.         String resString = new String();  
  107.         //step2去停用词  
  108.         //step3stemming,返回后一起做  
  109.         for(int i = 0; i < res.length; i++){  
  110.             if(!res[i].isEmpty() && !stopWordsArray.contains(res[i].toLowerCase())){  
  111.                 resString += " " + res[i].toLowerCase() + " ";  
  112.             }  
  113.         }  
  114.         return resString;  
  115.     }  
  116.   
  117.     /** 
  118.      * @param args 
  119.      * @throws IOException  
  120.      */  
  121.     public void BPPMain(String[] args) throws IOException {  
  122.         // TODO Auto-generated method stub  
  123.         DataPreProcess dataPrePro = new DataPreProcess();  
  124.         dataPrePro.doProcess("F:/DataMiningSample/orginSample");  
  125.   
  126.     }  
  127.   
  128. }  
steming的porter算法可以Google,有C及JAVA的实现版本,点击下载 porter算法JAVA版本

2、特征词的选取
首先统计经过预处理后在所有文档中出现不重复的单词一共有87554个,对这些词进行统计发现:
出现次数大于等于1次的词有87554个
出现次数大于等于3次的词有36456个 
出现次数大于等于4次的词有30095个
特征词的选取策略:
策略一:保留所有词作为特征词 共计87554个
策略二:选取出现次数大于等于4次的词作为特征词共计30095个 
特征词的选取策略:采用策略一,后面将对两种特征词选取策略的计算时间和平均准确率做对比
测试集与训练集的创建类CreateTrainAndTestSample.java如下
[java]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. package com.pku.yangliu;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.io.FileWriter;  
  7. import java.io.IOException;  
  8. import java.util.SortedMap;  
  9. import java.util.TreeMap;  
  10.   
  11. /**创建训练样例集合与测试样例集合 
  12.  * 
  13.  */  
  14. public class CreateTrainAndTestSample {  
  15.       
  16.     void filterSpecialWords() throws IOException {  
  17.         // TODO Auto-generated method stub  
  18.         String word;  
  19.         ComputeWordsVector cwv = new ComputeWordsVector();  
  20.         String fileDir = "F:/DataMiningSample/processedSample_includeNotSpecial";  
  21.         SortedMap<String,Double> wordMap = new TreeMap<String,Double>();  
  22.         wordMap = cwv.countWords(fileDir, wordMap);  
  23.         cwv.printWordMap(wordMap);//把wordMap输出到文件  
  24.         File[] sampleDir = new File(fileDir).listFiles();  
  25.         for(int i = 0; i < sampleDir.length; i++){  
  26.             File[] sample = sampleDir[i].listFiles();  
  27.             String targetDir = "F:/DataMiningSample/processedSampleOnlySpecial/"+sampleDir[i].getName();  
  28.             File targetDirFile = new File(targetDir);  
  29.             if(!targetDirFile.exists()){  
  30.                 targetDirFile.mkdir();  
  31.             }  
  32.             for(int j = 0;j < sample.length; j++){     
  33.                 String fileShortName = sample[j].getName();  
  34.                 if(fileShortName.contains("stemed")){  
  35.                     targetDir = "F:/DataMiningSample/processedSampleOnlySpecial/"+sampleDir[i].getName()+"/"+fileShortName.substring(0,5);  
  36.                     FileWriter tgWriter= new FileWriter(targetDir);  
  37.                     FileReader samReader = new FileReader(sample[j]);  
  38.                     BufferedReader samBR = new BufferedReader(samReader);  
  39.                     while((word = samBR.readLine()) != null){  
  40.                         if(wordMap.containsKey(word)){  
  41.                             tgWriter.append(word + "\n");  
  42.                         }  
  43.                     }  
  44.                     tgWriter.flush();  
  45.                     tgWriter.close();  
  46.                 }  
  47.             }  
  48.         }  
  49.     }  
  50.       
  51.     void createTestSamples(String fileDir, double trainSamplePercent,int indexOfSample,String classifyResultFile) throws IOException {  
  52.         // TODO Auto-generated method stub  
  53.         String word, targetDir;  
  54.         FileWriter crWriter = new FileWriter(classifyResultFile);//测试样例正确类目记录文件  
  55.         File[] sampleDir = new File(fileDir).listFiles();  
  56.         for(int i = 0; i < sampleDir.length; i++){  
  57.             File[] sample = sampleDir[i].listFiles();  
  58.             double testBeginIndex = indexOfSample*(sample.length * (1-trainSamplePercent));//测试样例的起始文件序号  
  59.             double testEndIndex = (indexOfSample+1)*(sample.length * (1-trainSamplePercent));//测试样例集的结束文件序号  
  60.             for(int j = 0;j < sample.length; j++){                 
  61.                 FileReader samReader = new FileReader(sample[j]);  
  62.                 BufferedReader samBR = new BufferedReader(samReader);  
  63.                 String fileShortName = sample[j].getName();  
  64.                 String subFileName = fileShortName;  
  65.                 if(j > testBeginIndex && j< testEndIndex){//序号在规定区间内的作为测试样本,需要为测试样本生成类别-序号文件,最后加入分类的结果,一行对应一个文件,方便统计准确率  
  66.                     targetDir = "F:/DataMiningSample/TestSample"+indexOfSample+"/"+sampleDir[i].getName();  
  67.                     crWriter.append(subFileName + " " + sampleDir[i].getName()+"\n");  
  68.                       
  69.                     }  
  70.                 else{//其余作为训练样本  
  71.                     targetDir = "F:/DataMiningSample/TrainSample"+indexOfSample+"/"+sampleDir[i].getName();  
  72.                 }  
  73.                 targetDir = targetDir.replace("\\","/");  
  74.                 File trainSamFile = new File(targetDir);  
  75.                 if(!trainSamFile.exists()){  
  76.                     trainSamFile.mkdir();  
  77.                 }  
  78.                 targetDir += "/"+subFileName;  
  79.                 FileWriter tsWriter = new FileWriter(new File(targetDir));  
  80.                 while((word = samBR.readLine()) != null){  
  81.                     tsWriter.append(word + "\n");  
  82.                 }  
  83.                 tsWriter.flush();  
  84.                 tsWriter.close();     
  85.             }  
  86.         }  
  87.         crWriter.flush();  
  88.         crWriter.close();  
  89.     }  
  90. }  

3、贝叶斯算法描述及实现
根据朴素贝叶斯公式,每个测试样例属于某个类别的概率 =  所有测试样例包含特征词类条件概率P(tk|c)之积 * 先验概率P(c)
在具体计算类条件概率和先验概率时,朴素贝叶斯分类器有两种模型
(1)多元分布模型( multinomial model )  –以单词为粒度,也就是说,考虑每个文件里面重复出现多次的单词。注意多项分布其实是从二项分布拓展出来的,如果采用多项分布模型,那么每个单词表示变量就不再是二值变量(出现/不出现),而是每个单词在文件中出现的次数
类条件概率P(tk|c)=(类c下单词tk在各个文档中出现过的次数之和+1)/(类c下单词总数+训练样本中不重复特征词总数)
先验概率P(c)=类c下的单词总数/整个训练样本的单词总数
(2)伯努利模型(Bernoulli model) –以文件为粒度,或者说是采用二项分布模型,伯努利实验即N次独立重复随机实验,只考虑事件发生/不发生,所以每个单词的表示变量是布尔型的
类条件概率P(tk|c)=(类c下包含单词tk的文件数+1)/(类c下文件总数+2)(注意:开始此处错写成了单词,多谢网友提醒后更正)
先验概率P(c)=类c下文件总数/整个训练样本的文件总数
本分类器选用多元分布模型计算,根据《Introduction to Information Retrieval 》,多元分布模型计算准确率更高
贝叶斯算法的实现有以下注意点:
(1) 计算概率用到了BigDecimal类实现任意精度计算
(2) 用交叉验证法做十次分类实验,对准确率取平均值
(3) 根据正确类目文件和分类结果文计算混淆矩阵并且输出
(4) Map<String,Double> cateWordsProb key为“类目_单词”, value为该类目下该单词的出现次数,避免重复计算
贝叶斯算法实现类如下 NaiveBayesianClassifier.java
[java]  view plain  copy
  在CODE上查看代码片 派生到我的代码片
  1. package com.pku.yangliu;  
  2.   
  3. import java.io.BufferedReader;  
  4. import java.io.File;  
  5. import java.io.FileReader;  
  6. import java.io.FileWriter;  
  7. import java.io.IOException;  
  8. import java.math.BigDecimal;  
  9. import java.util.Iterator;  
  10. import java.util.Map;  
  11. import java.util.Set;  
  12. import java.util.SortedSet;  
  13. import java.util.TreeMap;  
  14. import java.util.TreeSet;  
  15. import java.util.Vector;  
  16.   
  17. /**利用朴素贝叶斯算法对newsgroup文档集做分类,采用十组交叉测试取平均值 
  18.  * 采用多项式模型,stanford信息检索导论课件上面言多项式模型比伯努利模型准确度高 
  19.  * 类条件概率P(tk|c)=(类c 下单词tk 在各个文档中出现过的次数之和+1)/(类c下单词总数+|V|) 
  20.  * 
  21.  */  
  22. public class NaiveBayesianClassifier {  
  23.       
  24.     /**用贝叶斯法对测试文档集分类 
  25.      * @param trainDir 训练文档集目录 
  26.      * @param testDir 测试文档集目录 
  27.      * @param classifyResultFileNew 分类结果文件路径 
  28.      * @throws Exception  
  29.      */  
  30.     private void doProcess(String trainDir, String testDir,  
  31.             String classifyResultFileNew) throws Exception {  
  32.         // TODO Auto-generated method stub  
  33.         Map<String,Double> cateWordsNum = new TreeMap<String,Double>();//保存训练集每个类别的总词数  
  34.         Map<String,Double> cateWordsProb = new TreeMap<String,Double>();//保存训练样本每个类别中每个属性词的出现词数  
  35.         cateWordsProb = getCateWordsProb(trainDir);  
  36.         cateWordsNum = getCateWordsNum(trainDir);  
  37.         double totalWordsNum = 0.0;//记录所有训练集的总词数  
  38.         Set<Map.Entry<String,Double>> cateWordsNumSet = cateWordsNum.entrySet();  
  39.         for(Iterator<Map.Entry<String,Double>> it = cateWordsNumSet.iterator(); it.hasNext();){  
  40.             Map.Entry<String, Double> me = it.next();  
  41.             totalWordsNum += me.getValue();  
  42.         }  
  43.         //下面开始读取测试样例做分类  
  44.         Vector<String> testFileWords = new Vector<String>();  
  45.         String word;  
  46.         File[] testDirFiles = new File(testDir).listFiles();  
  47.         FileWriter crWriter = new FileWriter(classifyResultFileNew);  
  48.         for(int i = 0; i < testDirFiles.length; i++){  
  49.             File[] testSample = testDirFiles[i].listFiles();  
  50.             for(int j = 0;j < testSample.length; j++){  
  51.                 testFileWords.clear();  
  52.                 FileReader spReader = new FileReader(testSample[j]);  
  53.                 BufferedReader spBR = new BufferedReader(spReader);  
  54.                 while((word = spBR.readLine()) != null){  
  55.                     testFileWords.add(word);  
  56.                 }  
  57.                 //下面分别计算该测试样例属于二十个类别的概率  
  58.                 File[] trainDirFiles = new File(trainDir).listFiles();  
  59.                 BigDecimal maxP = new BigDecimal(0);  
  60.                 String bestCate = null;  
  61.                 for(int k = 0; k < trainDirFiles.length; k++){  
  62.                     BigDecimal p = computeCateProb(trainDirFiles[k], testFileWords, cateWordsNum, totalWordsNum, cateWordsProb);  
  63.                     if(k == 0){  
  64.                         maxP = p;  
  65.                         bestCate = trainDirFiles[k].getName();  
  66.                         continue;  
  67.                     }  
  68.                     if(p.compareTo(maxP) == 1){  
  69.                         maxP = p;  
  70.                         bestCate = trainDirFiles[k].getName();  
  71.                     }  
  72.                 }  
  73.                 crWriter.append(testSample[j].getName() + " " + bestCate + "\n");  
  74.                 crWriter.flush();  
  75.             }  
  76.         }  
  77.         crWriter.close();  
  78.     }  
  79.       
  80.     /**统计某类训练样本中每个单词的出现次数 
  81.      * @param strDir 训练样本集目录 
  82.      * @return Map<String,Double> cateWordsProb 用"类目_单词"对来索引的map,保存的val就是该类目下该单词的出现次数 
  83.      * @throws IOException  
  84.      */  
  85.     public Map<String,Double> getCateWordsProb(String strDir) throws IOException{  
  86.         Map<String,Double> cateWordsProb = new TreeMap<String,Double>();  
  87.         File sampleFile = new File(strDir);  
  88.         File [] sampleDir = sampleFile.listFiles();  
  89.         String word;  
  90.         for(int i = 0;i < sampleDir.length; i++){  
  91.             File [] sample = sampleDir[i].listFiles();  
  92.             for(int j = 0; j < sample.length; j++){  
  93.                 FileReader samReader = new FileReader(sample[j]);  
  94.                 BufferedReader samBR = new BufferedReader(samReader);  
  95.                 while((word = samBR.readLine()) != null){  
  96.                     String key = sampleDir[i].getName() + "_" + word;  
  97.                     if(cateWordsProb.containsKey(key)){  
  98.                         double count = cateWordsProb.get(key) + 1.0;  
  99.                         cateWordsProb.put(key, count);  
  100.                     }  
  101.                     else {  
  102.                         cateWordsProb.put(key, 1.0);  
  103.                     }  
  104.                 }  
  105.             }  
  106.         }  
  107.         return cateWordsProb;     
  108.     }  
  109.       
  110.     /**计算某一个测试样本属于某个类别的概率 
  111.      * @param Map<String, Double> cateWordsProb 记录每个目录中出现的单词及次数  
  112.      * @param File trainFile 该类别所有的训练样本所在目录 
  113.      * @param Vector<String> testFileWords 该测试样本中的所有词构成的容器 
  114.      * @param double totalWordsNum 记录所有训练样本的单词总数 
  115.      * @param Map<String, Double> cateWordsNum 记录每个类别的单词总数 
  116.      * @return BigDecimal 返回该测试样本在该类别中的概率 
  117.      * @throws Exception  
  118.      * @throws IOException  
  119.      */  
  120.     private BigDecimal computeCateProb(File trainFile, Vector<String> testFileWords, Map<String, Double> cateWordsNum, double totalWordsNum, Map<String, Double> cateWordsProb) throws Exception {  
  121.         // TODO Auto-generated method stub  
  122.         BigDecimal probability = new BigDecimal(1);  
  123.         double wordNumInCate = cateWordsNum.get(trainFile.getName());  
  124.         BigDecimal wordNumInCateBD = new BigDecimal(wordNumInCate);  
  125.         BigDecimal totalWordsNumBD = new BigDecimal(totalWordsNum);  
  126.         for(Iterator<String> it = testFileWords.iterator(); it.hasNext();){  
  127.             String me = it.next();  
  128.             String key = trainFile.getName()+"_"+me;  
  129.             double testFileWordNumInCate;  
  130.             if(cateWordsProb.containsKey(key)){  
  131.                 testFileWordNumInCate = cateWordsProb.get(key);  
  132.             }else testFileWordNumInCate = 0.0;  
  133.             BigDecimal testFileWordNumInCateBD = new BigDecimal(testFileWordNumInCate);  
  134.             BigDecimal xcProb = (testFileWordNumInCateBD.add(new BigDecimal(0.0001))).divide(totalWordsNumBD.add(wordNumInCateBD),10, BigDecimal.ROUND_CEILING);  
  135.             probability = probability.multiply(xcProb);  
  136.         }  
  137.         BigDecimal res = probability.multiply(wordNumInCateBD.divide(totalWordsNumBD,10, BigDecimal.ROUND_CEILING));  
  138.         return res;  
  139.     }  
  140.   
  141.     /**获得每个类目下的单词总数 
  142.      * @param trainDir 训练文档集目录 
  143.      * @return Map<String, Double> <目录名,单词总数>的map 
  144.      * @throws IOException  
  145.      */  
  146.     private Map<String, Double> getCateWordsNum(String trainDir) throws IOException {  
  147.         // TODO Auto-generated method stub  
  148.         Map<String,Double> cateWordsNum = new TreeMap<String,Double>();  
  149.         File[] sampleDir = new File(trainDir).listFiles();  
  150.         for(int i = 0; i < sampleDir.length; i++){  
  151.             double count = 0;  
  152.             File[] sample = sampleDir[i].listFiles();  
  153.             for(int j = 0;j < sample.length; j++){  
  154.                 FileReader spReader = new FileReader(sample[j]);  
  155.                 BufferedReader spBR = new BufferedReader(spReader);  
  156.                 while(spBR.readLine() != null){  
  157.                     count++;  
  158.                 }         
  159.             }  
  160.             cateWordsNum.put(sampleDir[i].getName(), count);  
  161.         }  
  162.         return cateWordsNum;  
  163.     }  
  164.       
  165.     /**根据正确类目文件和分类结果文件统计出准确率 
  166.      * @param classifyResultFile 正确类目文件 
  167.      * @param classifyResultFileNew 分类结果文件 
  168.      * @return double 分类的准确率 
  169.      * @throws IOException  
  170.      */  
  171.     double computeAccuracy(String classifyResultFile,  
  172.             String classifyResultFileNew) throws IOException {  
  173.         // TODO Auto-generated method stub  
  174.         Map<String,String> rightCate = new TreeMap<String,String>();  
  175.         Map<String,String> resultCate = new TreeMap<String,String>();  
  176.         rightCate = getMapFromResultFile(classifyResultFile);  
  177.         resultCate = getMapFromResultFile(classifyResultFileNew);  
  178.         Set<Map.Entry<String, String>> resCateSet = resultCate.entrySet();  
  179.         double rightCount = 0.0;  
  180.         for(Iterator<Map.Entry<String, String>> it = resCateSet.iterator(); it.hasNext();){  
  181.             Map.Entry<String, String> me = it.next();  
  182.             if(me.getValue().equals(rightCate.get(me.getKey()))){  
  183.                 rightCount ++;  
  184.             }  
  185.         }  
  186.         computerConfusionMatrix(rightCate,resultCate);  
  187.         return rightCount / resultCate.size();    
  188.     }  
  189.       
  190.     /**根据正确类目文件和分类结果文计算混淆矩阵并且输出 
  191.      * @param rightCate 正确类目对应map 
  192.      * @param resultCate 分类结果对应map 
  193.      * @return double 分类的准确率 
  194.      * @throws IOException  
  195.      */  
  196.     private void computerConfusionMatrix(Map<String, String> rightCate,  
  197.             Map<String, String> resultCate) {  
  198.         // TODO Auto-generated method stub    
  199.         int[][] confusionMatrix = new int[20][20];  
  200.         //首先求出类目对应的数组索引  
  201.         SortedSet<String> cateNames = new TreeSet<String>();  
  202.         Set<Map.Entry<String, String>> rightCateSet = rightCate.entrySet();  
  203.         for(Iterator<Map.Entry<String, String>> it = rightCateSet.iterator(); it.hasNext();){  
  204.             Map.Entry<String, String> me = it.next();  
  205.             cateNames.add(me.getValue());  
  206.         }  
  207.         cateNames.add("rec.sport.baseball");//防止数少一个类目  
  208.         String[] cateNamesArray = cateNames.toArray(new String[0]);  
  209.         Map<String,Integer> cateNamesToIndex = new TreeMap<String,Integer>();  
  210.         for(int i = 0; i < cateNamesArray.length; i++){  
  211.             cateNamesToIndex.put(cateNamesArray[i],i);  
  212.         }  
  213.         for(Iterator<Map.Entry<String, String>> it = rightCateSet.iterator(); it.hasNext();){  
  214.             Map.Entry<String, String> me = it.next();  
  215.             confusionMatrix[cateNamesToIndex.get(me.getValue())][cateNamesToIndex.get(resultCate.get(me.getKey()))]++;  
  216.         }  
  217.         //输出混淆矩阵  
  218.         double[] hangSum = new double[20];  
  219.         System.out.print("    ");  
  220.         for(int i = 0; i < 20; i++){  
  221.             System.out.print(i + "    ");  
  222.         }  
  223.         System.out.println();  
  224.         for(int i = 0; i < 20; i++){  
  225.             System.out.print(i + "    ");  
  226.             for(int j = 0; j < 20; j++){  
  227.                 System.out.print(confusionMatrix[i][j]+"    ");  
  228.                 hangSum[i] += confusionMatrix[i][j];  
  229.             }  
  230.             System.out.println(confusionMatrix[i][i] / hangSum[i]);  
  231.         }  
  232.         System.out.println();  
  233.     }  
  234.   
  235.     /**从分类结果文件中读取map 
  236.      * @param classifyResultFileNew 类目文件 
  237.      * @return Map<String, String> 由<文件名,类目名>保存的map 
  238.      * @throws IOException  
  239.      */  
  240.     private Map<String, String> getMapFromResultFile(  
  241.             String classifyResultFileNew) throws IOException {  
  242.         // TODO Auto-generated method stub  
  243.         File crFile = new File(classifyResultFileNew);  
  244.         FileReader crReader = new FileReader(crFile);  
  245.         BufferedReader crBR = new BufferedReader(crReader);  
  246.         Map<String, String> res = new TreeMap<String, String>();  
  247.         String[] s;  
  248.         String line;  
  249.         while((line = crBR.readLine()) != null){  
  250.             s = line.split(" ");  
  251.             res.put(s[0], s[1]);      
  252.         }  
  253.         return res;  
  254.     }  
  255.   
  256.     /** 
  257.      * @param args 
  258.      * @throws Exception  
  259.      */  
  260.     public void NaiveBayesianClassifierMain(String[] args) throws Exception {  
  261.          //TODO Auto-generated method stub  
  262.         //首先创建训练集和测试集  
  263.         CreateTrainAndTestSample ctt = new CreateTrainAndTestSample();  
  264.         NaiveBayesianClassifier nbClassifier = new NaiveBayesianClassifier();  
  265.         ctt.filterSpecialWords();//根据包含非特征词的文档集生成只包含特征词的文档集到processedSampleOnlySpecial目录下  
  266.         double[] accuracyOfEveryExp = new double[10];  
  267.         double accuracyAvg,sum = 0;  
  268.         for(int i = 0; i < 10; i++){//用交叉验证法做十次分类实验,对准确率取平均值   
  269.             String TrainDir = "F:/DataMiningSample/TrainSample"+i;  
  270.             String TestDir = "F:/DataMiningSample/TestSample"+i;  
  271.             String classifyRightCate = "F:/DataMiningSample/classifyRightCate"+i+".txt";  
  272.             String classifyResultFileNew = "F:/DataMiningSample/classifyResultNew"+i+".txt";  
  273.             ctt.createTestSamples("F:/DataMiningSample/processedSampleOnlySpecial"0.9, i,classifyRightCate);  
  274.             nbClassifier.doProcess(TrainDir,TestDir,classifyResultFileNew);  
  275.             accuracyOfEveryExp[i] = nbClassifier.computeAccuracy (classifyRightCate, classifyResultFileNew);  
  276.             System.out.println("The accuracy for Naive Bayesian Classifier in "+i+"th Exp is :" + accuracyOfEveryExp[i]);  
  277.         }  
  278.         for(int i = 0; i < 10; i++){  
  279.             sum += accuracyOfEveryExp[i];  
  280.         }  
  281.         accuracyAvg = sum / 10;  
  282.         System.out.println("The average accuracy for Naive Bayesian Classifier in all Exps is :" + accuracyAvg);  
  283.           
  284.     }  
  285. }  

4 朴素贝叶斯算法对newsgroup文档集做分类的结果

为方便计算混淆矩阵,将类目编号如下

0 alt.atheism 
1 comp.graphics 
2 comp.os.ms-windows.misc 
3comp.sys.ibm.pc.hdwar
4comp.sys.mac.hardwar
5 comp.windows.x 
6 misc.forsale 
7 rec.autos 
8 rec.motorcycles 
9 rec.sport.baseball 
10 rec.sport.hockey 
11 sci.crypt 
12 sci.electronics 
13 sci.med 
14 sci.space 
15 soc.religion.christian 
16 talk.politics.guns 
17 talk.politics.mideast 
18 talk.politics.misc 
19 talk.religion.misc

贝叶斯算法分类结果-混淆矩阵表示,以交叉验证的第6次实验结果为例,分类准确率达到80.47%
程序运行硬件环境:Intel Core 2 Duo CPU T5750 2GHZ, 2G内存,实验结果如下
取所有词共87554个作为特征词:10次交叉验证实验平均准确率78.19%,用时23min,准确率范围75.65%-80.47%,第6次实验准确率超过80%
取出现次数大于等于4次的词共计30095个作为特征词: 10次交叉验证实验平均准确率77.91%,用时22min,准确率范围75.51%-80.26%,第6次实验准确率超过80%
结论:朴素贝叶斯算法不必去除出现次数很低的词,因为出现次数很低的词的IDF比较   大,去除后分类准确率下降,而计算时间并没有显著减少
5 贝叶斯算法的改进
为了进一步提高贝叶斯算法的分类准确率,可以考虑
(1) 优化特征词的选取策略
(2)改进多项式模型的类条件概率的计算公式,改进为 类条件概率P(tk|c)=(类c下单词tk在各个文档中出现过的次数之和+0.001)/(类c下单词总数+训练样本中不重复特征词总数),分子当tk没有出现时,只加0.001,这样更加精确的描述的词的统计分布规律,做此改进后的混淆矩阵如下

可以看到第6次分组实验的准确率提高到84.79%,第7词分组实验的准确率达到85.24%,平均准确率由77.91%提高到了82.23%,优化效果还是很明显的
KNN算法描述及JAVA实现,和两种算法的准确率对比,见数据挖掘- 基于贝叶斯算法及KNN算法的newsgroup18828文档分类器的JAVA实现(下)
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值