根据贝叶斯定理实现的新闻自动分类

最近在做一个新闻搜索引擎时,需要对爬虫爬下来的网页,实现自动分类。基于文本分类实现的算法,常见的有VSM、贝叶斯、TF-IDF、KNN、决策树等等几种方式。其中,贝叶斯在大型的数据集上表现出来速度和准确度还挺不错的。

采用贝叶斯,实现文本分类的思路为:

1、计算词条在其所属分类中出现的概率:词频 = 某词条出现的次数 / 词条所属分类的词条总数

2、对于测试文本d,首先对其分词,然后按下面的公式计算该文本属于分类Cj的概率

文档出现在分类Cj  的概率 = (词条1 / Cj 分类下的文本数 * 词条2 / Cj 分类下的文本数.... 词条N / Cj 分类下的文本数) * (Cj 分类下的文本数 / 所有分类的总文本数)

3、将文本分到概率最大的那个类别中。


实验数据使用的是搜狗实验室的文本分类语料库,下载地址为http://www.sogou.com/labs/dl/c.html ,大小在30M左右,有9个分类,每个分类下有2000个左右的文本。文本的分词使用的是IK分词。 由于实验数据量比较大,就使用了序列化的方式,把第一次计算的结果保存下来,以便后续使用。实现代码如下:

package unit.test;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileFilter;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.lang3.StringUtils;
import org.apache.lucene.analysis.Analyzer;
import org.wltea.analyzer.core.IKSegmenter;
import org.wltea.analyzer.core.Lexeme;
import org.wltea.analyzer.lucene.IKAnalyzer;

/**
 * @Description: 文本自动分类 
 * @author houqirui
 */
public class AutoCategory implements Serializable{

	private static final long serialVersionUID = 1L;
	// 序列化
	public final static String SERIALIZABLE_PATH = "C:\\test\\SogouC.reduced\\Reduced\\category.ser";
	// public final static String SERIALIZABLE_PATH = "C:\\qirui\\data\\SogouC.reduced\\Reduced\\Train.ser";
	// 训练集
	private String trainPath = "C:\\test\\SogouC.reduced\\Reduced";
	// 分类名称
	private Map<String, String> categoryMap;
	// 每个分类的概率
	private Map<String, Double> categoryFreqMap;
	// Map<分类, Map<term, freq>>
	private Map<String, Map<String, Double>> freqMap;
	// 总文件数
	private long totalSize;
	// 每个分类下的文件数
	private Map<String, Long> categoryFileNums;
	// 分词器
    public transient Analyzer analyzer;
    
    public static AutoCategory instance;
    
    public static AutoCategory getIntstance() {
    	// 读取序列化在磁盘上的本类对象
    	FileInputStream fis = null;
    	ObjectInputStream oos = null;
    	try {
			File file = new File(SERIALIZABLE_PATH);
			if(file.length() != 0) {
				fis = new FileInputStream(SERIALIZABLE_PATH);
				oos = new ObjectInputStream(fis);
			    instance = (AutoCategory) oos.readObject();
			    instance.analyzer = new IKAnalyzer(true);
			} else {
				instance = new AutoCategory();
			}
		} catch (ClassNotFoundException e) {
			e.printStackTrace();
		} catch (IOException e) {
			e.printStackTrace();
		} finally {
			try {
				if(oos != null) {
					oos.close();
				}
				
				if(fis != null) {
					fis.close();
				}
			} catch (IOException e) {
				e.printStackTrace();
			}
		}
    	
    	return instance;
    }
    
    /**
     * 读取文本文件内容
     * @param filePath
     * @return
     */
    public String readTxt(String filePath) {
    	BufferedReader br = null;
    	StringBuilder sb = null;
    	
    	try {
			br = new BufferedReader(new FileReader(filePath));
			sb = new StringBuilder();
			
			String line = br.readLine();
			while(line != null) {
				sb.append(line);
				line = br.readLine();
			}
			
			return sb.toString();
		} catch (FileNotFoundException e) {
			e.printStackTrace();
		} catch (IOException e) {
			e.printStackTrace();
		} finally {
			try {
				if(br != null) {
					br.close();
				}
			} catch (IOException e) {
				e.printStackTrace();
			}
		}
    	
    	return "";
    }
    
    public void buildCategoryFreq() {
    	// 初始化
    	freqMap = new HashMap<String, Map<String, Double>>(); 
    	categoryMap = new HashMap<String, String>();
    	categoryFreqMap = new HashMap<String, Double>();
    	categoryFileNums = new HashMap<String, Long>();
    	
    	categoryMap.put("C000008", "财经");
    	categoryMap.put("C000010", "IT");
    	categoryMap.put("C000013", "健康");
    	categoryMap.put("C000014", "体育");
    	categoryMap.put("C000016", "旅游");
    	categoryMap.put("C000020", "教育");
    	categoryMap.put("C000022", "招聘");
    	categoryMap.put("C000023", "文化");
    	categoryMap.put("C000024", "军事");
        
    	// 总文件数
    	totalSize = getFilesNum();
    	
        // 计算各个类别的样本数
        Set<String> keySet = categoryMap.keySet();
        
        for(String category : keySet) {
            File f = new File(trainPath + File.separator + category);
            File[] files = f.listFiles(new FileFilter() {            	 
                @Override
                public boolean accept(File pathname) {
                    if (pathname.getName().endsWith(".txt")) {
                        return true;
                    }
                    return false;
                }
            });
            
            // 存放每个词条的出现次数
            Map<String, Double> termMap = new HashMap<String, Double>();
            if(files != null) {
            	
            	for (File txt : files) {
            		String content = readTxt(txt.getAbsolutePath());
            		// 分词
            		List<String> words = participle(content, false);
            		
            		// 统计每个词出现的次数
                    for (String word : words) {
                        if (termMap.containsKey(word)) {
                            Double wordCount = termMap.get(word);
                            termMap.put(word, wordCount + 1);
                        } else {
                        	termMap.put(word, 1.0);
                        }
                    }
            	}
            	
            	long fileNums = files.length;
            	freqMap.put(category, calcTermFreq(termMap, fileNums));
            	categoryFileNums.put(category, fileNums);
            	categoryFreqMap.put(category, fileNums / Double.valueOf(totalSize));            
            }
        }
        
        // 把计算结果序列化到本地 (空间换时间)
        FileOutputStream fos = null;
        ObjectOutputStream oos = null;
        try {
            fos = new FileOutputStream(SERIALIZABLE_PATH);
            oos = new ObjectOutputStream(fos);
            oos.writeObject(this);
        } catch (Exception e) {
            e.printStackTrace();
        } finally {
        	try {
				if(oos != null) {
					oos.close();
				}
				if(fos != null) {
					fos.close();
				}
			} catch (IOException e) {
				e.printStackTrace();
			}
        }
    }
    
    /**
     * 计算每个词条在其类别下出现的概率
     * @param termMap	每个词条出现的次数
     * @param fileNums	某一类别下的文件数
     * @return
     */
    public Map<String, Double> calcTermFreq(Map<String, Double> termMap, long fileNums) {
    	Map<String, Double> termFreqMap = new HashMap<String, Double>();
    	
    	Iterator<String> termIterator = termMap.keySet().iterator();
    	
    	String term = null;
    	Double termNum = null;
    	Double termFreq = null;	// 假设每个词条在该类别下的出现的概率为1
    	 
    	while(termIterator.hasNext()) {
    		term = termIterator.next();
    		termNum = termMap.get(term);
    		
    		// 计算该词条在该类别下出现的概率,如果这个词条在该类别下不存在,就给定一个极小的值,不影响计算
    		termFreq = (termNum == null) ? ((double) 1 / (fileNums + 1)) : (termNum / fileNums);
			termFreqMap.put(term, termFreq);
    	}
    	
    	return termFreqMap;
    }
    
    /**
     * 统计训练集的总文件数
     * @return
     */
    public long getFilesNum() {
    	long counter = 0;
    	
    	Set<String> keySet = categoryMap.keySet();
        
        for(String category : keySet) {
	    	File f = new File(trainPath + File.separator + category);
	        File[] files = f.listFiles(new FileFilter() {            	 
	            @Override
	            public boolean accept(File pathname) {
	                if (pathname.getName().endsWith(".txt")) {
	                    return true;
	                }
	                return false;
	            }
	        });
	        
	        counter = counter + files.length;
        }
    	
    	return counter;
    }
    
    /**
     * 分词
     * @param content 要分词的文本内容
     * @param isRepeat	分词后是否允许有重复的词条
     * @return	分词后的词条列表
     */
    public List<String> participle(String content, boolean isRepeat) {
    	if(StringUtils.isBlank(content)) {
    		return null;
    	}
    	
    	try {
			IKSegmenter ik = new IKSegmenter(new StringReader(content), true);
			Lexeme lex=null; 
			List<String> words = new ArrayList<String>();
			
			// 是否允许重复			
			if(isRepeat) {
				while((lex = ik.next())!=null){  
					words.add(lex.getLexemeText());
				}
			} else {
				String term = "";
				while((lex = ik.next())!=null){  
					term = lex.getLexemeText();
					if(words.contains(term)) {
						continue;
					}
					words.add(term);
				}
			}
			
			return words;
		} catch (IOException e) {
			e.printStackTrace();
		} 
    	
    	return null;
    }
    
    /**
     * 获得文档的分类
     * @param text
     */
    public Map<String, Double> getCategory(String text) {
    	// 分词,并且去重
    	List<String> text_words = participle(text, true);
    	
    	Map<String, Double> frequencyOfType = new HashMap<String, Double>();
    	Set<String> keySet = categoryMap.keySet();
    	
    	for(String cateKey : keySet) {
    		double docFreq = 1.0;	// 假设文档在每个分类下的出现的概率为1
    		
    		Map<String, Double> termFreq = freqMap.get(cateKey);
    		double term_frequency = 0.0;
    		Double tf = null;
    		long cateFileNums = categoryFileNums.get(cateKey);
    		
    		for (String word : text_words) {
    			// 获得该词条在该类别下出现的概率,如果这个词条在该类别下不存在,就给定一个极小的值,不影响计算
    			tf = termFreq.get(word);
    			term_frequency = tf == null ? ((double) 1 / cateFileNums) : tf;
    			// 文档出现在类别的概率, 在这里按照特征向量独立统计,即概率=词汇1/文章数 * 词汇2/文章数 。。。
                // 当double无限小的时候会归为0,为了避免 *10
    			docFreq = docFreq * term_frequency * 10;
    			docFreq = ((docFreq == 0.0) ? Double.MIN_VALUE : docFreq);
    		}
    		
    		docFreq = ((docFreq == 1.0) ? 0.0 : docFreq);
    		
    		// 此类别在所有类别中所占概率
            double classOfAll = categoryFreqMap.get(cateKey);
            
            // 根据贝叶斯公式 P(A|B)= P(B|A)*P(A)/P(B),由于P(B)是常数,在这里不做计算,不影响分类结果
            frequencyOfType.put(cateKey, docFreq * classOfAll);
    	}
    	
    	return frequencyOfType;
    }
    
    public static void main(String[] args) {
    	System.out.println("start time: " + new Date());
    	AutoCategory category = AutoCategory.getIntstance();
    	category.buildCategoryFreq();
    	
    	// String txt = "时报讯 昨天是五一黄金周的最后一天,游客们纷纷踏上了回家的旅程,宁波各大景区全面“退烧”。而此时,宁波的各大餐饮商场负责人却喜笑颜开。";
    	String txt = "消息:北约谴责成员国领导人同意在9月威尔士峰会期间减少东欧的长期驻军";
    	
    	Map<String, Double> map = category.getCategory(txt);
    	Set<String> keys = map.keySet();
    	// Double typeVal = Double.MIN_VALUE;
    	for(String key : keys) {
    		System.out.println(key + " : " + map.get(key));
    	}
    	System.out.println("-----------------------");
    	
    	String txt2 = "人民网8月18日讯 综合外媒报道,8月17日有消息指出,“伊斯兰国”在叙利亚阿勒颇与“自由军”激战时,捕获一名来自日本的“自由战士”。目前,“伊斯兰国”已在“推特”上确认,他们已将此人处决。  据驻约旦的日本叙利亚大使披露,日方是在16日晚些时候得知有日方人员在叙利亚被捕的消息的,日方人员在接受采访时称,日本使馆是从第三方线人处取得这一消息的,目前日方正在确认被抓人员的确切身份,并在努力尝试营救活动。但分析人士即指出,在心狠手辣的“伊斯兰国”武装手中,这名人质恐怕凶多吉少。  另据《读卖新闻》等日本权威媒体17日爆料,这名日本男子是在16日于叙利亚阿勒颇战区附近被“伊斯兰国”武装人员捕获的,当时这名男子正携带包括摄影器材在内的众多装备与“叙利亚自由军”武装人员一同行动,后在与“伊斯兰国”武装的冲突中被打散并抓获。(老任)";
    	Map<String, Double> map2 = category.getCategory(txt2);
    	Set<String> keys2 = map2.keySet();
    	// Double typeVal = Double.MIN_VALUE;
    	for(String key : keys2) {
    		System.out.println(key + " : " + map2.get(key));
    	}
    	System.out.println("end time: " + new Date());
    }
}


  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值