余弦定理实现新闻自动分类算法

前言

余弦定理,这个在初中课本中就出现过的公式,恐怕没有人不知道的吧。但是另外一个概念,可能不是很多的人会听说过,他叫空间向量,一般用e表示,高中课本中有专门讲过这个东西,有了余弦定理和向量空间,我们就可以做许多有意思的事情了,利用余弦定理计算文本相似度的算法就是其中一个很典型的例子。当然这个话题太老,说的人太多,没有什么新意,恰巧周末阅读了吴军博士的<<数学之美>>这门书,书中讲到了利用余弦定理实现新闻分类,于是就索性完成这个算法的初步模型。感兴趣的可以继续往下看。

算法背景

在以往,如果对一则新闻进行归类,一般使用的都是人工分类的办法,大体上看一下标题和首尾两段文字,就能知道新闻是属于财经的,体育的又或者是健康类的。但是在当今信息爆炸的时代,这显然是不可能完成的任务,所以我们急切的相用机器自己帮我们”分类“。最好的形式是我给计算机提供大量的已分类好的数据,等强大的计算机大脑训练好了这个分类模型,后边的事情就是他来完成了。看起来这好像很高深,很困难的样子,但是其实我们自己也可以写一个,只是效果可能不会那么好。

分类器实现原理

新闻自动分类器实现的本质也是利用余弦定理比较文本的相似度,于是这个问题的难点就在于这个特征向量哪里来,怎么去获得。特征向量,特征向量,关键两个字在于特征,新闻的特征就在于他的关键词,我的简单理解就是专业性的词语,换句话说,就是属于某类新闻特有的词语,比如金融类的新闻,关键词一般就是股票啊,公司啊,上市啊等等词语。这些词的寻找可以通过统计词频的方式实现,最后统计出来的关键词,进行降序排列,一个关键词就代表一个新的维度。 那么新的问题又来了,我要统计词频,那么就得首先进行分词,要把每个新闻句子的主谓宾统统挖掘出来啊,好像这个工作比我整个算法还要复杂的样子。OK,其实已经有人已经帮我们把这个问题解决了,在这个算法中我使用的是中科大的ICTCLAS分词系统,效果非常棒,举个例子,下面是我原始的新闻内容:

  1. 教育部副部长:教育公平是社会公平重要基础
  2. 723日,教育部党组副书记、副部长杜玉波为全国学联全体代表作《教育综合改革与青年学生成长成才》的专题报告。中国青年网记者张炎良摄
  3. 人民网北京724日电(记者贺迎春实习生王斯慧

经过分词系统处理后的分词效果:

  1. 教育部/nt副/b部长/n:/wm教育/v公平/an是/vshi社会/n公平/a重要/a基础/n
  2. 7月/t23日/t,/wd教育部/nt党组/n副/b书记/n、/wn副/b部长/n杜玉波/nr为/p全国学联/nt全体/n代表作/n《/wkz教育/vn综合/vn改革/vn与/cc青年/n学生/n成长/vi成才/vi》/wky的/ude1专题/n报告/n。/wj中国/ns青年/n网/n记者/n张/q炎/ng良/d摄/vg
  3. 人民/n网/n北京/ns7月/t24日/t电/n(/wkz记者/n贺/vg迎春/n实习生/n王斯慧/nr)/wky昨日/t,/wd教育部/nt副/b部长

OK,有了这个分词的结果之后,后面的事情就水到渠成了。

算法的实现步骤

1、给定训练的新闻数据集。

2、通过分词系统统计词频的方式,统计词频最高的N位作为特征词,即特征向量

3、输入测试数据,同样统计词频,并于训练数据的进行商的操作,得到特征向量值

4、最后利用余弦定理计算相似度,并与最小阈值做比较。

算法的代码实现

ICTCLAS工具类ICTCLAS.java:

  1. packageNewsClassify;
  2. importjava.io.File;
  3. importjava.io.FileOutputStream;
  4. importjava.io.InputStream;
  5. importjava.util.StringTokenizer;
  6. publicclassICTCLAS50{
  7. static{
  8. try{
  9. Stringlibpath=System.getProperty("user.dir")+"\\lib";
  10. Stringpath=null;
  11. StringTokenizerst=newStringTokenizer(libpath,
  12. System.getProperty("path.separator"));
  13. if(st.hasMoreElements()){
  14. path=st.nextToken();
  15. }
  16. //copyalldllfilestojavalibpath
  17. FiledllFile=null;
  18. InputStreaminputStream=null;
  19. FileOutputStreamoutputStream=null;
  20. byte[]array=null;
  21. dllFile=newFile(newFile(path),"ICTCLAS50.dll");
  22. if(!dllFile.exists()){
  23. inputStream=ICTCLAS50.class.getResource("/lib/ICTCLAS50.dll")
  24. .openStream();
  25. outputStream=newFileOutputStream(dllFile);
  26. array=newbyte[1024];
  27. for(inti=inputStream.read(array);i!=-1;i=inputStream
  28. .read(array)){
  29. outputStream.write(array,0,i);
  30. }
  31. outputStream.close();
  32. }
  33. }catch(Exceptione){
  34. e.printStackTrace();
  35. }
  36. try{
  37. //loadJniCall.dll
  38. System.loadLibrary("ICTCLAS50");
  39. System.out.println("4444");
  40. }catch(Errore){
  41. e.printStackTrace();
  42. }
  43. }
  44. publicnativebooleanICTCLAS_Init(byte[]sPath);
  45. publicnativebooleanICTCLAS_Exit();
  46. publicnativeintICTCLAS_ImportUserDictFile(byte[]sPath,inteCodeType);
  47. publicnativeintICTCLAS_SaveTheUsrDic();
  48. publicnativeintICTCLAS_SetPOSmap(intnPOSmap);
  49. publicnativebooleanICTCLAS_FileProcess(byte[]sSrcFilename,
  50. inteCodeType,intbPOSTagged,byte[]sDestFilename);
  51. publicnativebyte[]ICTCLAS_ParagraphProcess(byte[]sSrc,inteCodeType,
  52. intbPOSTagged);
  53. publicnativebyte[]nativeProcAPara(byte[]sSrc,inteCodeType,
  54. intbPOStagged);
  55. }
新闻实体类New.java

  1. packageNewsClassify;
  2. /**
  3. *词语实体类
  4. *
  5. *@authorlyq
  6. *
  7. */
  8. publicclassWordimplementsComparable<Word>{
  9. //词语名称
  10. Stringname;
  11. //词频
  12. Integercount;
  13. publicWord(Stringname,Integercount){
  14. this.name=name;
  15. this.count=count;
  16. }
  17. @Override
  18. publicintcompareTo(Wordo){
  19. //TODOAuto-generatedmethodstub
  20. returno.count.compareTo(this.count);
  21. }
  22. }
分类算法类NewsClassify.java:

  1. packageNewsClassify;
  2. importjava.io.BufferedReader;
  3. importjava.io.File;
  4. importjava.io.FileReader;
  5. importjava.io.IOException;
  6. importjava.util.ArrayList;
  7. importjava.util.Collections;
  8. /**
  9. *分类算法模型
  10. *
  11. *@authorlyq
  12. *
  13. */
  14. publicclassNewsClassifyTool{
  15. //余弦向量空间维数
  16. privateintvectorNum;
  17. //余弦相似度最小满足阈值
  18. privatedoubleminSupportValue;
  19. //当前训练数据的新闻类别
  20. privateStringnewsType;
  21. //训练新闻数据文件地址
  22. privateArrayList<String>trainDataPaths;
  23. publicNewsClassifyTool(ArrayList<String>trainDataPaths,StringnewsType,
  24. intvectorNum,doubleminSupportValue){
  25. this.trainDataPaths=trainDataPaths;
  26. this.newsType=newsType;
  27. this.vectorNum=vectorNum;
  28. this.minSupportValue=minSupportValue;
  29. }
  30. /**
  31. *从文件中读取数据
  32. */
  33. privateStringreadDataFile(StringfilePath){
  34. Filefile=newFile(filePath);
  35. StringBuilderstrBuilder=null;
  36. try{
  37. BufferedReaderin=newBufferedReader(newFileReader(file));
  38. Stringstr;
  39. strBuilder=newStringBuilder();
  40. while((str=in.readLine())!=null){
  41. strBuilder.append(str);
  42. }
  43. in.close();
  44. }catch(IOExceptione){
  45. e.getStackTrace();
  46. }
  47. returnstrBuilder.toString();
  48. }
  49. /**
  50. *计算测试数据的特征向量
  51. */
  52. privatedouble[]calCharacterVectors(StringfilePath){
  53. intindex;
  54. double[]vectorDimensions;
  55. double[]temp;
  56. Newsnews;
  57. NewstestNews;
  58. StringnewsCotent;
  59. StringtestContent;
  60. StringparseContent;
  61. //高频词汇
  62. ArrayList<Word>frequentWords;
  63. ArrayList<Word>wordList;
  64. testContent=readDataFile(filePath);
  65. testNews=newNews(testContent);
  66. parseNewsContent(filePath);
  67. index=filePath.indexOf('.');
  68. parseContent=readDataFile(filePath.substring(0,index)+"-split.txt");
  69. testNews.statWords(parseContent);
  70. vectorDimensions=newdouble[vectorNum];
  71. //计算训练数据集的类别的特征向量
  72. for(Stringpath:this.trainDataPaths){
  73. newsCotent=readDataFile(path);
  74. news=newNews(newsCotent);
  75. //进行分词操作
  76. index=path.indexOf('.');
  77. parseNewsContent(path);
  78. parseContent=readDataFile(path.substring(0,index)+"-split.txt");
  79. news.statWords(parseContent);
  80. wordList=news.wordDatas;
  81. //将词频统计结果降序排列
  82. Collections.sort(wordList);
  83. frequentWords=newArrayList<Word>();
  84. //截取出前vectorDimens的词语
  85. for(inti=0;i<vectorNum;i++){
  86. frequentWords.add(wordList.get(i));
  87. }
  88. temp=testNews.calVectorDimension(frequentWords);
  89. //将特征向量值进行累加
  90. for(inti=0;i<vectorDimensions.length;i++){
  91. vectorDimensions[i]+=temp[i];
  92. }
  93. }
  94. //最后取平均向量值作为最终的特征向量值
  95. for(inti=0;i<vectorDimensions.length;i++){
  96. vectorDimensions[i]/=trainDataPaths.size();
  97. }
  98. returnvectorDimensions;
  99. }
  100. /**
  101. *根据求得的向量空间计算余弦相似度值
  102. *
  103. *@paramvectorDimension
  104. *已求得的测试数据的特征向量值
  105. *@return
  106. */
  107. privatedoublecalCosValue(double[]vectorDimension){
  108. doubleresult;
  109. doublenum1;
  110. doublenum2;
  111. doubletemp1;
  112. doubletemp2;
  113. //标准的特征向量,每个维度上都为1
  114. double[]standardVector;
  115. standardVector=newdouble[vectorNum];
  116. for(inti=0;i<vectorNum;i++){
  117. standardVector[i]=1;
  118. }
  119. temp1=0;
  120. temp2=0;
  121. num1=0;
  122. for(inti=0;i<vectorNum;i++){
  123. //累加分子的值
  124. num1+=vectorDimension[i]*standardVector[i];
  125. //累加分母的值
  126. temp1+=vectorDimension[i]*vectorDimension[i];
  127. temp2+=standardVector[i]*standardVector[i];
  128. }
  129. num2=Math.sqrt(temp1)*Math.sqrt(temp2);
  130. //套用余弦定理公式进行计算
  131. result=num1/num2;
  132. returnresult;
  133. }
  134. /**
  135. *进行新闻分类
  136. *
  137. *@paramfilePath
  138. *测试新闻数据文件地址
  139. */
  140. publicvoidnewsClassify(StringfilePath){
  141. doubleresult;
  142. double[]vectorDimension;
  143. vectorDimension=calCharacterVectors(filePath);
  144. result=calCosValue(vectorDimension);
  145. //如果余弦相似度值满足最小阈值要求,则属于目标分类
  146. if(result>=minSupportValue){
  147. System.out.println(String.format("最终相似度结果为%s,大于阈值%s,所以此新闻属于%s类新闻",
  148. result,minSupportValue,newsType));
  149. }else{
  150. System.out.println(String.format("最终相似度结果为%s,小于阈值%s,所以此新闻不属于%s类新闻",
  151. result,minSupportValue,newsType));
  152. }
  153. }
  154. /**
  155. *利用分词系统进行新闻内容的分词
  156. *
  157. *@paramsrcPath
  158. *新闻文件路径
  159. */
  160. privatevoidparseNewsContent(StringsrcPath){
  161. //TODOAuto-generatedmethodstub
  162. intindex;
  163. StringdirApi;
  164. StringdesPath;
  165. dirApi=System.getProperty("user.dir")+"\\lib";
  166. //组装输出路径值
  167. index=srcPath.indexOf('.');
  168. desPath=srcPath.substring(0,index)+"-split.txt";
  169. try{
  170. ICTCLAS50testICTCLAS50=newICTCLAS50();
  171. //分词所需库的路径、初始化
  172. if(testICTCLAS50.ICTCLAS_Init(dirApi.getBytes("GB2312"))==false){
  173. System.out.println("InitFail!");
  174. return;
  175. }
  176. //将文件名string类型转为byte类型
  177. byte[]Inputfilenameb=srcPath.getBytes();
  178. //分词处理后输出文件名、将文件名string类型转为byte类型
  179. byte[]Outputfilenameb=desPath.getBytes();
  180. //文件分词(第一个参数为输入文件的名,第二个参数为文件编码类型,第三个参数为是否标记词性集1yes,0
  181. //no,第四个参数为输出文件名)
  182. testICTCLAS50.ICTCLAS_FileProcess(Inputfilenameb,0,1,
  183. Outputfilenameb);
  184. //退出分词器
  185. testICTCLAS50.ICTCLAS_Exit();
  186. }catch(Exceptionex){
  187. ex.printStackTrace();
  188. }
  189. }
  190. }
场景测试了Client.java:

  1. packageNewsClassify;
  2. importjava.util.ArrayList;
  3. /**
  4. *新闻分类算法测试类
  5. *@authorlyq
  6. *
  7. */
  8. publicclassClient{
  9. publicstaticvoidmain(String[]args){
  10. StringtestFilePath1;
  11. StringtestFilePath2;
  12. StringtestFilePath3;
  13. Stringpath;
  14. StringnewsType;
  15. intvectorNum;
  16. doubleminSupportValue;
  17. ArrayList<String>trainDataPaths;
  18. NewsClassifyToolclassifyTool;
  19. //添加测试以及训练集数据文件路径
  20. testFilePath1="C:\\Users\\lyq\\Desktop\\icon\\test\\testNews1.txt";
  21. testFilePath2="C:\\Users\\lyq\\Desktop\\icon\\test\\testNews2.txt";
  22. testFilePath3="C:\\Users\\lyq\\Desktop\\icon\\test\\testNews3.txt";
  23. trainDataPaths=newArrayList<String>();
  24. path="C:\\Users\\lyq\\Desktop\\icon\\test\\trainNews1.txt";
  25. trainDataPaths.add(path);
  26. path="C:\\Users\\lyq\\Desktop\\icon\\test\\trainNews2.txt";
  27. trainDataPaths.add(path);
  28. newsType="金融";
  29. vectorNum=10;
  30. minSupportValue=0.45;
  31. classifyTool=newNewsClassifyTool(trainDataPaths,newsType,vectorNum,minSupportValue);
  32. classifyTool.newsClassify(testFilePath1);
  33. classifyTool.newsClassify(testFilePath2);
  34. classifyTool.newsClassify(testFilePath3);
  35. }
  36. }
结果输出:

  1. 最终相似度结果为0.39999999999999997,小于阈值0.45,所以此新闻不属于金融类新闻
  2. 最终相似度结果为0.4635393084189425,大于阈值0.45,所以此新闻属于金融类新闻
  3. 最终相似度结果为0.661835948543857,大于阈值0.45,所以此新闻属于金融类新闻
测试数据以及全部代码,链接在此:https://github.com/linyiqun/news-classifier
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值