最近在做一个新闻搜索引擎时,需要对爬虫爬下来的网页,实现自动分类。基于文本分类实现的算法,常见的有VSM、贝叶斯、TF-IDF、KNN、决策树等等几种方式。其中,贝叶斯在大型的数据集上表现出来速度和准确度还挺不错的。
采用贝叶斯,实现文本分类的思路为:
1、计算词条在其所属分类中出现的概率:词频 = 某词条出现的次数 / 词条所属分类的词条总数
2、对于测试文本d,首先对其分词,然后按下面的公式计算该文本属于分类Cj的概率
文档出现在分类Cj 的概率 = (词条1 / Cj 分类下的文本数 * 词条2 / Cj 分类下的文本数.... 词条N / Cj 分类下的文本数) * (Cj 分类下的文本数 / 所有分类的总文本数)
3、将文本分到概率最大的那个类别中。
实验数据使用的是搜狗实验室的文本分类语料库,下载地址为http://www.sogou.com/labs/dl/c.html ,大小在30M左右,有9个分类,每个分类下有2000个左右的文本。文本的分词使用的是IK分词。 由于实验数据量比较大,就使用了序列化的方式,把第一次计算的结果保存下来,以便后续使用。实现代码如下:
package unit.test;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileFilter;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.Date;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang3.StringUtils;
import org.apache.lucene.analysis.Analyzer;
import org.wltea.analyzer.core.IKSegmenter;
import org.wltea.analyzer.core.Lexeme;
import org.wltea.analyzer.lucene.IKAnalyzer;
/**
* @Description: 文本自动分类
* @author houqirui
*/
public class AutoCategory implements Serializable{
private static final long serialVersionUID = 1L;
// 序列化
public final static String SERIALIZABLE_PATH = "C:\\test\\SogouC.reduced\\Reduced\\category.ser";
// public final static String SERIALIZABLE_PATH = "C:\\qirui\\data\\SogouC.reduced\\Reduced\\Train.ser";
// 训练集
private String trainPath = "C:\\test\\SogouC.reduced\\Reduced";
// 分类名称
private Map<String, String> categoryMap;
// 每个分类的概率
private Map<String, Double> categoryFreqMap;
// Map<分类, Map<term, freq>>
private Map<String, Map<String, Double>> freqMap;
// 总文件数
private long totalSize;
// 每个分类下的文件数
private Map<String, Long> categoryFileNums;
// 分词器
public transient Analyzer analyzer;
public static AutoCategory instance;
public static AutoCategory getIntstance() {
// 读取序列化在磁盘上的本类对象
FileInputStream fis = null;
ObjectInputStream oos = null;
try {
File file = new File(SERIALIZABLE_PATH);
if(file.length() != 0) {
fis = new FileInputStream(SERIALIZABLE_PATH);
oos = new ObjectInputStream(fis);
instance = (AutoCategory) oos.readObject();
instance.analyzer = new IKAnalyzer(true);
} else {
instance = new AutoCategory();
}
} catch (ClassNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
} finally {
try {
if(oos != null) {
oos.close();
}
if(fis != null) {
fis.close();
}
} catch (IOException e) {
e.printStackTrace();
}
}
return instance;
}
/**
* 读取文本文件内容
* @param filePath
* @return
*/
public String readTxt(String filePath) {
BufferedReader br = null;
StringBuilder sb = null;
try {
br = new BufferedReader(new FileReader(filePath));
sb = new StringBuilder();
String line = br.readLine();
while(line != null) {
sb.append(line);
line = br.readLine();
}
return sb.toString();
} catch (FileNotFoundException e) {
e.printStackTrace();
} catch (IOException e) {
e.printStackTrace();
} finally {
try {
if(br != null) {
br.close();
}
} catch (IOException e) {
e.printStackTrace();
}
}
return "";
}
public void buildCategoryFreq() {
// 初始化
freqMap = new HashMap<String, Map<String, Double>>();
categoryMap = new HashMap<String, String>();
categoryFreqMap = new HashMap<String, Double>();
categoryFileNums = new HashMap<String, Long>();
categoryMap.put("C000008", "财经");
categoryMap.put("C000010", "IT");
categoryMap.put("C000013", "健康");
categoryMap.put("C000014", "体育");
categoryMap.put("C000016", "旅游");
categoryMap.put("C000020", "教育");
categoryMap.put("C000022", "招聘");
categoryMap.put("C000023", "文化");
categoryMap.put("C000024", "军事");
// 总文件数
totalSize = getFilesNum();
// 计算各个类别的样本数
Set<String> keySet = categoryMap.keySet();
for(String category : keySet) {
File f = new File(trainPath + File.separator + category);
File[] files = f.listFiles(new FileFilter() {
@Override
public boolean accept(File pathname) {
if (pathname.getName().endsWith(".txt")) {
return true;
}
return false;
}
});
// 存放每个词条的出现次数
Map<String, Double> termMap = new HashMap<String, Double>();
if(files != null) {
for (File txt : files) {
String content = readTxt(txt.getAbsolutePath());
// 分词
List<String> words = participle(content, false);
// 统计每个词出现的次数
for (String word : words) {
if (termMap.containsKey(word)) {
Double wordCount = termMap.get(word);
termMap.put(word, wordCount + 1);
} else {
termMap.put(word, 1.0);
}
}
}
long fileNums = files.length;
freqMap.put(category, calcTermFreq(termMap, fileNums));
categoryFileNums.put(category, fileNums);
categoryFreqMap.put(category, fileNums / Double.valueOf(totalSize));
}
}
// 把计算结果序列化到本地 (空间换时间)
FileOutputStream fos = null;
ObjectOutputStream oos = null;
try {
fos = new FileOutputStream(SERIALIZABLE_PATH);
oos = new ObjectOutputStream(fos);
oos.writeObject(this);
} catch (Exception e) {
e.printStackTrace();
} finally {
try {
if(oos != null) {
oos.close();
}
if(fos != null) {
fos.close();
}
} catch (IOException e) {
e.printStackTrace();
}
}
}
/**
* 计算每个词条在其类别下出现的概率
* @param termMap 每个词条出现的次数
* @param fileNums 某一类别下的文件数
* @return
*/
public Map<String, Double> calcTermFreq(Map<String, Double> termMap, long fileNums) {
Map<String, Double> termFreqMap = new HashMap<String, Double>();
Iterator<String> termIterator = termMap.keySet().iterator();
String term = null;
Double termNum = null;
Double termFreq = null; // 假设每个词条在该类别下的出现的概率为1
while(termIterator.hasNext()) {
term = termIterator.next();
termNum = termMap.get(term);
// 计算该词条在该类别下出现的概率,如果这个词条在该类别下不存在,就给定一个极小的值,不影响计算
termFreq = (termNum == null) ? ((double) 1 / (fileNums + 1)) : (termNum / fileNums);
termFreqMap.put(term, termFreq);
}
return termFreqMap;
}
/**
* 统计训练集的总文件数
* @return
*/
public long getFilesNum() {
long counter = 0;
Set<String> keySet = categoryMap.keySet();
for(String category : keySet) {
File f = new File(trainPath + File.separator + category);
File[] files = f.listFiles(new FileFilter() {
@Override
public boolean accept(File pathname) {
if (pathname.getName().endsWith(".txt")) {
return true;
}
return false;
}
});
counter = counter + files.length;
}
return counter;
}
/**
* 分词
* @param content 要分词的文本内容
* @param isRepeat 分词后是否允许有重复的词条
* @return 分词后的词条列表
*/
public List<String> participle(String content, boolean isRepeat) {
if(StringUtils.isBlank(content)) {
return null;
}
try {
IKSegmenter ik = new IKSegmenter(new StringReader(content), true);
Lexeme lex=null;
List<String> words = new ArrayList<String>();
// 是否允许重复
if(isRepeat) {
while((lex = ik.next())!=null){
words.add(lex.getLexemeText());
}
} else {
String term = "";
while((lex = ik.next())!=null){
term = lex.getLexemeText();
if(words.contains(term)) {
continue;
}
words.add(term);
}
}
return words;
} catch (IOException e) {
e.printStackTrace();
}
return null;
}
/**
* 获得文档的分类
* @param text
*/
public Map<String, Double> getCategory(String text) {
// 分词,并且去重
List<String> text_words = participle(text, true);
Map<String, Double> frequencyOfType = new HashMap<String, Double>();
Set<String> keySet = categoryMap.keySet();
for(String cateKey : keySet) {
double docFreq = 1.0; // 假设文档在每个分类下的出现的概率为1
Map<String, Double> termFreq = freqMap.get(cateKey);
double term_frequency = 0.0;
Double tf = null;
long cateFileNums = categoryFileNums.get(cateKey);
for (String word : text_words) {
// 获得该词条在该类别下出现的概率,如果这个词条在该类别下不存在,就给定一个极小的值,不影响计算
tf = termFreq.get(word);
term_frequency = tf == null ? ((double) 1 / cateFileNums) : tf;
// 文档出现在类别的概率, 在这里按照特征向量独立统计,即概率=词汇1/文章数 * 词汇2/文章数 。。。
// 当double无限小的时候会归为0,为了避免 *10
docFreq = docFreq * term_frequency * 10;
docFreq = ((docFreq == 0.0) ? Double.MIN_VALUE : docFreq);
}
docFreq = ((docFreq == 1.0) ? 0.0 : docFreq);
// 此类别在所有类别中所占概率
double classOfAll = categoryFreqMap.get(cateKey);
// 根据贝叶斯公式 P(A|B)= P(B|A)*P(A)/P(B),由于P(B)是常数,在这里不做计算,不影响分类结果
frequencyOfType.put(cateKey, docFreq * classOfAll);
}
return frequencyOfType;
}
public static void main(String[] args) {
System.out.println("start time: " + new Date());
AutoCategory category = AutoCategory.getIntstance();
category.buildCategoryFreq();
// String txt = "时报讯 昨天是五一黄金周的最后一天,游客们纷纷踏上了回家的旅程,宁波各大景区全面“退烧”。而此时,宁波的各大餐饮商场负责人却喜笑颜开。";
String txt = "消息:北约谴责成员国领导人同意在9月威尔士峰会期间减少东欧的长期驻军";
Map<String, Double> map = category.getCategory(txt);
Set<String> keys = map.keySet();
// Double typeVal = Double.MIN_VALUE;
for(String key : keys) {
System.out.println(key + " : " + map.get(key));
}
System.out.println("-----------------------");
String txt2 = "人民网8月18日讯 综合外媒报道,8月17日有消息指出,“伊斯兰国”在叙利亚阿勒颇与“自由军”激战时,捕获一名来自日本的“自由战士”。目前,“伊斯兰国”已在“推特”上确认,他们已将此人处决。 据驻约旦的日本叙利亚大使披露,日方是在16日晚些时候得知有日方人员在叙利亚被捕的消息的,日方人员在接受采访时称,日本使馆是从第三方线人处取得这一消息的,目前日方正在确认被抓人员的确切身份,并在努力尝试营救活动。但分析人士即指出,在心狠手辣的“伊斯兰国”武装手中,这名人质恐怕凶多吉少。 另据《读卖新闻》等日本权威媒体17日爆料,这名日本男子是在16日于叙利亚阿勒颇战区附近被“伊斯兰国”武装人员捕获的,当时这名男子正携带包括摄影器材在内的众多装备与“叙利亚自由军”武装人员一同行动,后在与“伊斯兰国”武装的冲突中被打散并抓获。(老任)";
Map<String, Double> map2 = category.getCategory(txt2);
Set<String> keys2 = map2.keySet();
// Double typeVal = Double.MIN_VALUE;
for(String key : keys2) {
System.out.println(key + " : " + map2.get(key));
}
System.out.println("end time: " + new Date());
}
}