数据挖掘-基于贝叶斯算法及KNN算法的newsgroup18828文本分类器的JAVA实现(上)

(update 2012.12.28 关于本项目下载及运行的常见问题 FAQ见newsgroup18828文本分类器、文本聚类器、关联分析频繁模式挖掘算法的Java实现工程下载及运行FAQ)

本文主要内容如下:
对newsgroup文档集进行预处理,提取出30095 个特征词

计算每篇文档中的特征词的TF*IDF值,实现文档向量化,在KNN算法中使用

用JAVA实现了KNN算法及朴素贝叶斯算法的newsgroup文本分类器

1、Newsgroup文档集介绍

Newsgroups最早由Lang于1995收集并在[Lang 1995]中使用。它含有20000篇左右的Usenet文档,几乎平均分配20个不同的新闻组。除了其中4.5%的文档属于两个或两个以上的新闻组以外,其余文档仅属于一个新闻组,因此它通常被作为单标注分类问题来处理。Newsgroups已经成为文本分及聚类中常用的文档集。美国MIT大学Jason Rennie对Newsgroups作了必要的处理,使得每个文档只属于一个新闻组,形成Newsgroups-18828。

2、Newsgroup文档预处理

要做文本分类首先得完成文本的预处理,预处理的主要步骤如下

STEP ONE:英文词法分析,去除数字、连字符、标点符号、特殊字符,所有大写字母转换成小写,可以用正则表达式
String res[] = line.split("[^a-zA-Z]");
STEP TWO:去停用词,过滤对分类无价值的词
STEP THRE:词根还原stemming,基于Porter算法
文档预处理类 DataPreProcess.java如下
[java] view plain copy 在CODE上查看代码片 派生到我的代码片
  1. packagecom.pku.yangliu;
  2. importjava.io.BufferedReader;
  3. importjava.io.File;
  4. importjava.io.FileReader;
  5. importjava.io.FileWriter;
  6. importjava.io.IOException;
  7. importjava.util.ArrayList;
  8. /**
  9. *Newsgroups文档集预处理类
  10. */
  11. publicclassDataPreProcess{
  12. /**输入文件调用处理数据函数
  13. *@paramstrDirnewsgroup文件目录的绝对路径
  14. *@throwsIOException
  15. */
  16. publicvoiddoProcess(StringstrDir)throwsIOException{
  17. FilefileDir=newFile(strDir);
  18. if(!fileDir.exists()){
  19. System.out.println("Filenotexist:"+strDir);
  20. return;
  21. }
  22. StringsubStrDir=strDir.substring(strDir.lastIndexOf('/'));
  23. StringdirTarget=strDir+"/../../processedSample_includeNotSpecial"+subStrDir;
  24. FilefileTarget=newFile(dirTarget);
  25. if(!fileTarget.exists()){//注意processedSample需要先建立目录建出来,否则会报错,因为母目录不存在
  26. fileTarget.mkdir();
  27. }
  28. File[]srcFiles=fileDir.listFiles();
  29. String[]stemFileNames=newString[srcFiles.length];
  30. for(inti=0;i<srcFiles.length;i++){
  31. StringfileFullName=srcFiles[i].getCanonicalPath();
  32. StringfileShortName=srcFiles[i].getName();
  33. if(!newFile(fileFullName).isDirectory()){//确认子文件名不是目录如果是可以再次递归调用
  34. System.out.println("Beginpreprocess:"+fileFullName);
  35. StringBuilderstringBuilder=newStringBuilder();
  36. stringBuilder.append(dirTarget+"/"+fileShortName);
  37. createProcessFile(fileFullName,stringBuilder.toString());
  38. stemFileNames[i]=stringBuilder.toString();
  39. }
  40. else{
  41. fileFullName=fileFullName.replace("\\","/");
  42. doProcess(fileFullName);
  43. }
  44. }
  45. //下面调用stem算法
  46. if(stemFileNames.length>0&&stemFileNames[0]!=null){
  47. Stemmer.porterMain(stemFileNames);
  48. }
  49. }
  50. /**进行文本预处理生成目标文件
  51. *@paramsrcDir源文件文件目录的绝对路径
  52. *@paramtargetDir生成的目标文件的绝对路径
  53. *@throwsIOException
  54. */
  55. privatestaticvoidcreateProcessFile(StringsrcDir,StringtargetDir)throwsIOException{
  56. //TODOAuto-generatedmethodstub
  57. FileReadersrcFileReader=newFileReader(srcDir);
  58. FileReaderstopWordsReader=newFileReader("F:/DataMiningSample/stopwords.txt");
  59. FileWritertargetFileWriter=newFileWriter(targetDir);
  60. BufferedReadersrcFileBR=newBufferedReader(srcFileReader);//装饰模式
  61. BufferedReaderstopWordsBR=newBufferedReader(stopWordsReader);
  62. Stringline,resLine,stopWordsLine;
  63. //用stopWordsBR够着停用词的ArrayList容器
  64. ArrayList<String>stopWordsArray=newArrayList<String>();
  65. while((stopWordsLine=stopWordsBR.readLine())!=null){
  66. if(!stopWordsLine.isEmpty()){
  67. stopWordsArray.add(stopWordsLine);
  68. }
  69. }
  70. while((line=srcFileBR.readLine())!=null){
  71. resLine=lineProcess(line,stopWordsArray);
  72. if(!resLine.isEmpty()){
  73. //按行写,一行写一个单词
  74. String[]tempStr=resLine.split("");//\s
  75. for(inti=0;i<tempStr.length;i++){
  76. if(!tempStr[i].isEmpty()){
  77. targetFileWriter.append(tempStr[i]+"\n");
  78. }
  79. }
  80. }
  81. }
  82. targetFileWriter.flush();
  83. targetFileWriter.close();
  84. srcFileReader.close();
  85. stopWordsReader.close();
  86. srcFileBR.close();
  87. stopWordsBR.close();
  88. }
  89. /**对每行字符串进行处理,主要是词法分析、去停用词和stemming
  90. *@paramline待处理的一行字符串
  91. *@paramArrayList<String>停用词数组
  92. *@returnString处理好的一行字符串,是由处理好的单词重新生成,以空格为分隔符
  93. *@throwsIOException
  94. */
  95. privatestaticStringlineProcess(Stringline,ArrayList<String>stopWordsArray)throwsIOException{
  96. //TODOAuto-generatedmethodstub
  97. //step1英文词法分析,去除数字、连字符、标点符号、特殊字符,所有大写字母转换成小写,可以考虑用正则表达式
  98. Stringres[]=line.split("[^a-zA-Z]");
  99. //这里要小心,防止把有单词中间有数字和连字符的单词截断了,但是截断也没事
  100. StringresString=newString();
  101. //step2去停用词
  102. //step3stemming,返回后一起做
  103. for(inti=0;i<res.length;i++){
  104. if(!res[i].isEmpty()&&!stopWordsArray.contains(res[i].toLowerCase())){
  105. resString+=""+res[i].toLowerCase()+"";
  106. }
  107. }
  108. returnresString;
  109. }
  110. /**
  111. *@paramargs
  112. *@throwsIOException
  113. */
  114. publicvoidBPPMain(String[]args)throwsIOException{
  115. //TODOAuto-generatedmethodstub
  116. DataPreProcessdataPrePro=newDataPreProcess();
  117. dataPrePro.doProcess("F:/DataMiningSample/orginSample");
  118. }
  119. }
steming的porter算法可以Google,有C及JAVA的实现版本,点击下载 porter算法JAVA版本

2、特征词的选取
首先统计经过预处理后在所有文档中出现不重复的单词一共有87554个,对这些词进行统计发现:
出现次数大于等于1次的词有87554个
出现次数大于等于3次的词有36456个
出现次数大于等于4次的词有30095个
特征词的选取策略:
策略一:保留所有词作为特征词 共计87554个
策略二:选取出现次数大于等于4次的词作为特征词共计30095个
特征词的选取策略:采用策略一,后面将对两种特征词选取策略的计算时间和平均准确率做对比
测试集与训练集的创建类CreateTrainAndTestSample.java如下
[java] view plain copy 在CODE上查看代码片 派生到我的代码片
  1. packagecom.pku.yangliu;
  2. importjava.io.BufferedReader;
  3. importjava.io.File;
  4. importjava.io.FileReader;
  5. importjava.io.FileWriter;
  6. importjava.io.IOException;
  7. importjava.util.SortedMap;
  8. importjava.util.TreeMap;
  9. /**创建训练样例集合与测试样例集合
  10. *
  11. */
  12. publicclassCreateTrainAndTestSample{
  13. voidfilterSpecialWords()throwsIOException{
  14. //TODOAuto-generatedmethodstub
  15. Stringword;
  16. ComputeWordsVectorcwv=newComputeWordsVector();
  17. StringfileDir="F:/DataMiningSample/processedSample_includeNotSpecial";
  18. SortedMap<String,Double>wordMap=newTreeMap<String,Double>();
  19. wordMap=cwv.countWords(fileDir,wordMap);
  20. cwv.printWordMap(wordMap);//把wordMap输出到文件
  21. File[]sampleDir=newFile(fileDir).listFiles();
  22. for(inti=0;i<sampleDir.length;i++){
  23. File[]sample=sampleDir[i].listFiles();
  24. StringtargetDir="F:/DataMiningSample/processedSampleOnlySpecial/"+sampleDir[i].getName();
  25. FiletargetDirFile=newFile(targetDir);
  26. if(!targetDirFile.exists()){
  27. targetDirFile.mkdir();
  28. }
  29. for(intj=0;j<sample.length;j++){
  30. StringfileShortName=sample[j].getName();
  31. if(fileShortName.contains("stemed")){
  32. targetDir="F:/DataMiningSample/processedSampleOnlySpecial/"+sampleDir[i].getName()+"/"+fileShortName.substring(0,5);
  33. FileWritertgWriter=newFileWriter(targetDir);
  34. FileReadersamReader=newFileReader(sample[j]);
  35. BufferedReadersamBR=newBufferedReader(samReader);
  36. while((word=samBR.readLine())!=null){
  37. if(wordMap.containsKey(word)){
  38. tgWriter.append(word+"\n");
  39. }
  40. }
  41. tgWriter.flush();
  42. tgWriter.close();
  43. }
  44. }
  45. }
  46. }
  47. voidcreateTestSamples(StringfileDir,doubletrainSamplePercent,intindexOfSample,StringclassifyResultFile)throwsIOException{
  48. //TODOAuto-generatedmethodstub
  49. Stringword,targetDir;
  50. FileWritercrWriter=newFileWriter(classifyResultFile);//测试样例正确类目记录文件
  51. File[]sampleDir=newFile(fileDir).listFiles();
  52. for(inti=0;i<sampleDir.length;i++){
  53. File[]sample=sampleDir[i].listFiles();
  54. doubletestBeginIndex=indexOfSample*(sample.length*(1-trainSamplePercent));//测试样例的起始文件序号
  55. doubletestEndIndex=(indexOfSample+1)*(sample.length*(1-trainSamplePercent));//测试样例集的结束文件序号
  56. for(intj=0;j<sample.length;j++){
  57. FileReadersamReader=newFileReader(sample[j]);
  58. BufferedReadersamBR=newBufferedReader(samReader);
  59. StringfileShortName=sample[j].getName();
  60. StringsubFileName=fileShortName;
  61. if(j>testBeginIndex&&j<testEndIndex){//序号在规定区间内的作为测试样本,需要为测试样本生成类别-序号文件,最后加入分类的结果,一行对应一个文件,方便统计准确率
  62. targetDir="F:/DataMiningSample/TestSample"+indexOfSample+"/"+sampleDir[i].getName();
  63. crWriter.append(subFileName+""+sampleDir[i].getName()+"\n");
  64. }
  65. else{//其余作为训练样本
  66. targetDir="F:/DataMiningSample/TrainSample"+indexOfSample+"/"+sampleDir[i].getName();
  67. }
  68. targetDir=targetDir.replace("\\","/");
  69. FiletrainSamFile=newFile(targetDir);
  70. if(!trainSamFile.exists()){
  71. trainSamFile.mkdir();
  72. }
  73. targetDir+="/"+subFileName;
  74. FileWritertsWriter=newFileWriter(newFile(targetDir));
  75. while((word=samBR.readLine())!=null){
  76. tsWriter.append(word+"\n");
  77. }
  78. tsWriter.flush();
  79. tsWriter.close();
  80. }
  81. }
  82. crWriter.flush();
  83. crWriter.close();
  84. }
  85. }

3、贝叶斯算法描述及实现
根据朴素贝叶斯公式,每个测试样例属于某个类别的概率 = 所有测试样例包含特征词类条件概率P(tk|c)之积 * 先验概率P(c)
在具体计算类条件概率和先验概率时,朴素贝叶斯分类器有两种模型
(1)多元分布模型( multinomial model ) –以单词为粒度,也就是说,考虑每个文件里面重复出现多次的单词。注意多项分布其实是从二项分布拓展出来的,如果采用多项分布模型,那么每个单词表示变量就不再是二值变量(出现/不出现),而是每个单词在文件中出现的次数
类条件概率P(tk|c)=(类c下单词tk在各个文档中出现过的次数之和+1)/(类c下单词总数+训练样本中不重复特征词总数)
先验概率P(c)=类c下的单词总数/整个训练样本的单词总数
(2)伯努利模型(Bernoulli model) –以文件为粒度,或者说是采用二项分布模型,伯努利实验即N次独立重复随机实验,只考虑事件发生/不发生,所以每个单词的表示变量是布尔型的
类条件概率P(tk|c)=(类c下包含单词tk的文件数+1)/(类c下文件总数+2)(注意:开始此处错写成了单词,多谢网友提醒后更正)
先验概率P(c)=类c下文件总数/整个训练样本的文件总数
本分类器选用多元分布模型计算,根据《Introduction to Information Retrieval 》,多元分布模型计算准确率更高
贝叶斯算法的实现有以下注意点:
(1) 计算概率用到了BigDecimal类实现任意精度计算
(2) 用交叉验证法做十次分类实验,对准确率取平均值
(3) 根据正确类目文件和分类结果文计算混淆矩阵并且输出
(4) Map<String,Double> cateWordsProb key为“类目_单词”, value为该类目下该单词的出现次数,避免重复计算
贝叶斯算法实现类如下 NaiveBayesianClassifier.java
[java] view plain copy 在CODE上查看代码片 派生到我的代码片
  1. packagecom.pku.yangliu;
  2. importjava.io.BufferedReader;
  3. importjava.io.File;
  4. importjava.io.FileReader;
  5. importjava.io.FileWriter;
  6. importjava.io.IOException;
  7. importjava.math.BigDecimal;
  8. importjava.util.Iterator;
  9. importjava.util.Map;
  10. importjava.util.Set;
  11. importjava.util.SortedSet;
  12. importjava.util.TreeMap;
  13. importjava.util.TreeSet;
  14. importjava.util.Vector;
  15. /**利用朴素贝叶斯算法对newsgroup文档集做分类,采用十组交叉测试取平均值
  16. *采用多项式模型,stanford信息检索导论课件上面言多项式模型比伯努利模型准确度高
  17. *类条件概率P(tk|c)=(类c下单词tk在各个文档中出现过的次数之和+1)/(类c下单词总数+|V|)
  18. *
  19. */
  20. publicclassNaiveBayesianClassifier{
  21. /**用贝叶斯法对测试文档集分类
  22. *@paramtrainDir训练文档集目录
  23. *@paramtestDir测试文档集目录
  24. *@paramclassifyResultFileNew分类结果文件路径
  25. *@throwsException
  26. */
  27. privatevoiddoProcess(StringtrainDir,StringtestDir,
  28. StringclassifyResultFileNew)throwsException{
  29. //TODOAuto-generatedmethodstub
  30. Map<String,Double>cateWordsNum=newTreeMap<String,Double>();//保存训练集每个类别的总词数
  31. Map<String,Double>cateWordsProb=newTreeMap<String,Double>();//保存训练样本每个类别中每个属性词的出现词数
  32. cateWordsProb=getCateWordsProb(trainDir);
  33. cateWordsNum=getCateWordsNum(trainDir);
  34. doubletotalWordsNum=0.0;//记录所有训练集的总词数
  35. Set<Map.Entry<String,Double>>cateWordsNumSet=cateWordsNum.entrySet();
  36. for(Iterator<Map.Entry<String,Double>>it=cateWordsNumSet.iterator();it.hasNext();){
  37. Map.Entry<String,Double>me=it.next();
  38. totalWordsNum+=me.getValue();
  39. }
  40. //下面开始读取测试样例做分类
  41. Vector<String>testFileWords=newVector<String>();
  42. Stringword;
  43. File[]testDirFiles=newFile(testDir).listFiles();
  44. FileWritercrWriter=newFileWriter(classifyResultFileNew);
  45. for(inti=0;i<testDirFiles.length;i++){
  46. File[]testSample=testDirFiles[i].listFiles();
  47. for(intj=0;j<testSample.length;j++){
  48. testFileWords.clear();
  49. FileReaderspReader=newFileReader(testSample[j]);
  50. BufferedReaderspBR=newBufferedReader(spReader);
  51. while((word=spBR.readLine())!=null){
  52. testFileWords.add(word);
  53. }
  54. //下面分别计算该测试样例属于二十个类别的概率
  55. File[]trainDirFiles=newFile(trainDir).listFiles();
  56. BigDecimalmaxP=newBigDecimal(0);
  57. StringbestCate=null;
  58. for(intk=0;k<trainDirFiles.length;k++){
  59. BigDecimalp=computeCateProb(trainDirFiles[k],testFileWords,cateWordsNum,totalWordsNum,cateWordsProb);
  60. if(k==0){
  61. maxP=p;
  62. bestCate=trainDirFiles[k].getName();
  63. continue;
  64. }
  65. if(p.compareTo(maxP)==1){
  66. maxP=p;
  67. bestCate=trainDirFiles[k].getName();
  68. }
  69. }
  70. crWriter.append(testSample[j].getName()+""+bestCate+"\n");
  71. crWriter.flush();
  72. }
  73. }
  74. crWriter.close();
  75. }
  76. /**统计某类训练样本中每个单词的出现次数
  77. *@paramstrDir训练样本集目录
  78. *@returnMap<String,Double>cateWordsProb用"类目_单词"对来索引的map,保存的val就是该类目下该单词的出现次数
  79. *@throwsIOException
  80. */
  81. publicMap<String,Double>getCateWordsProb(StringstrDir)throwsIOException{
  82. Map<String,Double>cateWordsProb=newTreeMap<String,Double>();
  83. FilesampleFile=newFile(strDir);
  84. File[]sampleDir=sampleFile.listFiles();
  85. Stringword;
  86. for(inti=0;i<sampleDir.length;i++){
  87. File[]sample=sampleDir[i].listFiles();
  88. for(intj=0;j<sample.length;j++){
  89. FileReadersamReader=newFileReader(sample[j]);
  90. BufferedReadersamBR=newBufferedReader(samReader);
  91. while((word=samBR.readLine())!=null){
  92. Stringkey=sampleDir[i].getName()+"_"+word;
  93. if(cateWordsProb.containsKey(key)){
  94. doublecount=cateWordsProb.get(key)+1.0;
  95. cateWordsProb.put(key,count);
  96. }
  97. else{
  98. cateWordsProb.put(key,1.0);
  99. }
  100. }
  101. }
  102. }
  103. returncateWordsProb;
  104. }
  105. /**计算某一个测试样本属于某个类别的概率
  106. *@paramMap<String,Double>cateWordsProb记录每个目录中出现的单词及次数
  107. *@paramFiletrainFile该类别所有的训练样本所在目录
  108. *@paramVector<String>testFileWords该测试样本中的所有词构成的容器
  109. *@paramdoubletotalWordsNum记录所有训练样本的单词总数
  110. *@paramMap<String,Double>cateWordsNum记录每个类别的单词总数
  111. *@returnBigDecimal返回该测试样本在该类别中的概率
  112. *@throwsException
  113. *@throwsIOException
  114. */
  115. privateBigDecimalcomputeCateProb(FiletrainFile,Vector<String>testFileWords,Map<String,Double>cateWordsNum,doubletotalWordsNum,Map<String,Double>cateWordsProb)throwsException{
  116. //TODOAuto-generatedmethodstub
  117. BigDecimalprobability=newBigDecimal(1);
  118. doublewordNumInCate=cateWordsNum.get(trainFile.getName());
  119. BigDecimalwordNumInCateBD=newBigDecimal(wordNumInCate);
  120. BigDecimaltotalWordsNumBD=newBigDecimal(totalWordsNum);
  121. for(Iterator<String>it=testFileWords.iterator();it.hasNext();){
  122. Stringme=it.next();
  123. Stringkey=trainFile.getName()+"_"+me;
  124. doubletestFileWordNumInCate;
  125. if(cateWordsProb.containsKey(key)){
  126. testFileWordNumInCate=cateWordsProb.get(key);
  127. }elsetestFileWordNumInCate=0.0;
  128. BigDecimaltestFileWordNumInCateBD=newBigDecimal(testFileWordNumInCate);
  129. BigDecimalxcProb=(testFileWordNumInCateBD.add(newBigDecimal(0.0001))).divide(totalWordsNumBD.add(wordNumInCateBD),10,BigDecimal.ROUND_CEILING);
  130. probability=probability.multiply(xcProb);
  131. }
  132. BigDecimalres=probability.multiply(wordNumInCateBD.divide(totalWordsNumBD,10,BigDecimal.ROUND_CEILING));
  133. returnres;
  134. }
  135. /**获得每个类目下的单词总数
  136. *@paramtrainDir训练文档集目录
  137. *@returnMap<String,Double><目录名,单词总数>的map
  138. *@throwsIOException
  139. */
  140. privateMap<String,Double>getCateWordsNum(StringtrainDir)throwsIOException{
  141. //TODOAuto-generatedmethodstub
  142. Map<String,Double>cateWordsNum=newTreeMap<String,Double>();
  143. File[]sampleDir=newFile(trainDir).listFiles();
  144. for(inti=0;i<sampleDir.length;i++){
  145. doublecount=0;
  146. File[]sample=sampleDir[i].listFiles();
  147. for(intj=0;j<sample.length;j++){
  148. FileReaderspReader=newFileReader(sample[j]);
  149. BufferedReaderspBR=newBufferedReader(spReader);
  150. while(spBR.readLine()!=null){
  151. count++;
  152. }
  153. }
  154. cateWordsNum.put(sampleDir[i].getName(),count);
  155. }
  156. returncateWordsNum;
  157. }
  158. /**根据正确类目文件和分类结果文件统计出准确率
  159. *@paramclassifyResultFile正确类目文件
  160. *@paramclassifyResultFileNew分类结果文件
  161. *@returndouble分类的准确率
  162. *@throwsIOException
  163. */
  164. doublecomputeAccuracy(StringclassifyResultFile,
  165. StringclassifyResultFileNew)throwsIOException{
  166. //TODOAuto-generatedmethodstub
  167. Map<String,String>rightCate=newTreeMap<String,String>();
  168. Map<String,String>resultCate=newTreeMap<String,String>();
  169. rightCate=getMapFromResultFile(classifyResultFile);
  170. resultCate=getMapFromResultFile(classifyResultFileNew);
  171. Set<Map.Entry<String,String>>resCateSet=resultCate.entrySet();
  172. doublerightCount=0.0;
  173. for(Iterator<Map.Entry<String,String>>it=resCateSet.iterator();it.hasNext();){
  174. Map.Entry<String,String>me=it.next();
  175. if(me.getValue().equals(rightCate.get(me.getKey()))){
  176. rightCount++;
  177. }
  178. }
  179. computerConfusionMatrix(rightCate,resultCate);
  180. returnrightCount/resultCate.size();
  181. }
  182. /**根据正确类目文件和分类结果文计算混淆矩阵并且输出
  183. *@paramrightCate正确类目对应map
  184. *@paramresultCate分类结果对应map
  185. *@returndouble分类的准确率
  186. *@throwsIOException
  187. */
  188. privatevoidcomputerConfusionMatrix(Map<String,String>rightCate,
  189. Map<String,String>resultCate){
  190. //TODOAuto-generatedmethodstub
  191. int[][]confusionMatrix=newint[20][20];
  192. //首先求出类目对应的数组索引
  193. SortedSet<String>cateNames=newTreeSet<String>();
  194. Set<Map.Entry<String,String>>rightCateSet=rightCate.entrySet();
  195. for(Iterator<Map.Entry<String,String>>it=rightCateSet.iterator();it.hasNext();){
  196. Map.Entry<String,String>me=it.next();
  197. cateNames.add(me.getValue());
  198. }
  199. cateNames.add("rec.sport.baseball");//防止数少一个类目
  200. String[]cateNamesArray=cateNames.toArray(newString[0]);
  201. Map<String,Integer>cateNamesToIndex=newTreeMap<String,Integer>();
  202. for(inti=0;i<cateNamesArray.length;i++){
  203. cateNamesToIndex.put(cateNamesArray[i],i);
  204. }
  205. for(Iterator<Map.Entry<String,String>>it=rightCateSet.iterator();it.hasNext();){
  206. Map.Entry<String,String>me=it.next();
  207. confusionMatrix[cateNamesToIndex.get(me.getValue())][cateNamesToIndex.get(resultCate.get(me.getKey()))]++;
  208. }
  209. //输出混淆矩阵
  210. double[]hangSum=newdouble[20];
  211. System.out.print("");
  212. for(inti=0;i<20;i++){
  213. System.out.print(i+"");
  214. }
  215. System.out.println();
  216. for(inti=0;i<20;i++){
  217. System.out.print(i+"");
  218. for(intj=0;j<20;j++){
  219. System.out.print(confusionMatrix[i][j]+"");
  220. hangSum[i]+=confusionMatrix[i][j];
  221. }
  222. System.out.println(confusionMatrix[i][i]/hangSum[i]);
  223. }
  224. System.out.println();
  225. }
  226. /**从分类结果文件中读取map
  227. *@paramclassifyResultFileNew类目文件
  228. *@returnMap<String,String>由<文件名,类目名>保存的map
  229. *@throwsIOException
  230. */
  231. privateMap<String,String>getMapFromResultFile(
  232. StringclassifyResultFileNew)throwsIOException{
  233. //TODOAuto-generatedmethodstub
  234. FilecrFile=newFile(classifyResultFileNew);
  235. FileReadercrReader=newFileReader(crFile);
  236. BufferedReadercrBR=newBufferedReader(crReader);
  237. Map<String,String>res=newTreeMap<String,String>();
  238. String[]s;
  239. Stringline;
  240. while((line=crBR.readLine())!=null){
  241. s=line.split("");
  242. res.put(s[0],s[1]);
  243. }
  244. returnres;
  245. }
  246. /**
  247. *@paramargs
  248. *@throwsException
  249. */
  250. publicvoidNaiveBayesianClassifierMain(String[]args)throwsException{
  251. //TODOAuto-generatedmethodstub
  252. //首先创建训练集和测试集
  253. CreateTrainAndTestSamplectt=newCreateTrainAndTestSample();
  254. NaiveBayesianClassifiernbClassifier=newNaiveBayesianClassifier();
  255. ctt.filterSpecialWords();//根据包含非特征词的文档集生成只包含特征词的文档集到processedSampleOnlySpecial目录下
  256. double[]accuracyOfEveryExp=newdouble[10];
  257. doubleaccuracyAvg,sum=0;
  258. for(inti=0;i<10;i++){//用交叉验证法做十次分类实验,对准确率取平均值
  259. StringTrainDir="F:/DataMiningSample/TrainSample"+i;
  260. StringTestDir="F:/DataMiningSample/TestSample"+i;
  261. StringclassifyRightCate="F:/DataMiningSample/classifyRightCate"+i+".txt";
  262. StringclassifyResultFileNew="F:/DataMiningSample/classifyResultNew"+i+".txt";
  263. ctt.createTestSamples("F:/DataMiningSample/processedSampleOnlySpecial",0.9,i,classifyRightCate);
  264. nbClassifier.doProcess(TrainDir,TestDir,classifyResultFileNew);
  265. accuracyOfEveryExp[i]=nbClassifier.computeAccuracy(classifyRightCate,classifyResultFileNew);
  266. System.out.println("TheaccuracyforNaiveBayesianClassifierin"+i+"thExpis:"+accuracyOfEveryExp[i]);
  267. }
  268. for(inti=0;i<10;i++){
  269. sum+=accuracyOfEveryExp[i];
  270. }
  271. accuracyAvg=sum/10;
  272. System.out.println("TheaverageaccuracyforNaiveBayesianClassifierinallExpsis:"+accuracyAvg);
  273. }
  274. }

4 朴素贝叶斯算法对newsgroup文档集做分类的结果

为方便计算混淆矩阵,将类目编号如下

0 alt.atheism
1 comp.graphics
2 comp.os.ms-windows.misc
3comp.sys.ibm.pc.hdwar
4comp.sys.mac.hardwar
5 comp.windows.x
6 misc.forsale
7 rec.autos
8 rec.motorcycles
9 rec.sport.baseball
10 rec.sport.hockey
11 sci.crypt
12 sci.electronics
13 sci.med
14 sci.space
15 soc.religion.christian
16 talk.politics.guns
17 talk.politics.mideast
18 talk.politics.misc
19 talk.religion.misc

贝叶斯算法分类结果-混淆矩阵表示,以交叉验证的第6次实验结果为例,分类准确率达到80.47%
程序运行硬件环境:Intel Core 2 Duo CPU T5750 2GHZ, 2G内存,实验结果如下
取所有词共87554个作为特征词:10次交叉验证实验平均准确率78.19%,用时23min,准确率范围75.65%-80.47%,第6次实验准确率超过80%
取出现次数大于等于4次的词共计30095个作为特征词: 10次交叉验证实验平均准确率77.91%,用时22min,准确率范围75.51%-80.26%,第6次实验准确率超过80%
结论:朴素贝叶斯算法不必去除出现次数很低的词,因为出现次数很低的词的IDF比较 大,去除后分类准确率下降,而计算时间并没有显著减少
5 贝叶斯算法的改进
为了进一步提高贝叶斯算法的分类准确率,可以考虑
(1) 优化特征词的选取策略
(2)改进多项式模型的类条件概率的计算公式,改进为类条件概率P(tk|c)=(类c下单词tk在各个文档中出现过的次数之和+0.001)/(类c下单词总数+训练样本中不重复特征词总数),分子当tk没有出现时,只加0.001,这样更加精确的描述的词的统计分布规律,做此改进后的混淆矩阵如下

可以看到第6次分组实验的准确率提高到84.79%,第7词分组实验的准确率达到85.24%,平均准确率由77.91%提高到了82.23%,优化效果还是很明显的
KNN算法描述及JAVA实现,和两种算法的准确率对比,见数据挖掘- 基于贝叶斯算法及KNN算法的newsgroup18828文档分类器的JAVA实现(下)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值