/**
* 训练器
*
* @author duyf
*
*/
class Train implements Serializable {
/**
*
*/
private static final long serialVersionUID = 1L;
public final static String SERIALIZABLE_PATH = "D:\\workspace\\Test\\SogouC.mini\\Sample\\Train.ser";
// 训练集的位置
private String trainPath = "D:\\workspace\\Test\\SogouC.mini\\Sample";
// 类别序号对应的实际名称
private Map classMap = new HashMap();
// 类别对应的txt文本数
private Map classP = new ConcurrentHashMap();
// 所有文本数
private AtomicInteger actCount = new AtomicInteger(0);
// 每个类别对应的词典和频数
private Map> classWordMap = new ConcurrentHashMap>();
// 分词器
private transient Participle participle;
private static Train trainInstance = new Train();
public static Train getInstance() {
trainInstance = new Train();
// 读取序列化在硬盘的本类对象
FileInputStream fis;
try {
File f = new File(SERIALIZABLE_PATH);
if (f.length() != 0) {
fis = new FileInputStream(SERIALIZABLE_PATH);
ObjectInputStream oos = new ObjectInputStream(fis);
trainInstance = (Train) oos.readObject();
trainInstance.participle = new IkParticiple();
} else {
trainInstance = new Train();
}
} catch (Exception e) {
e.printStackTrace();
}
return trainInstance;
}
private Train() {
this.participle = new IkParticiple();
}
public String readtxt(String path) {
BufferedReader br = null;
StringBuilder str = null;
try {
br = new BufferedReader(new FileReader(path));
str = new StringBuilder();
String r = br.readLine();
while (r != null) {
str.append(r);
r = br.readLine();
}
return str.toString();
} catch (IOException ex) {
ex.printStackTrace();
} finally {
if (br != null) {
try {
br.close();
} catch (IOException e) {
e.printStackTrace();
}
}
str = null;
br = null;
}
return "";
}
/**
* 训练数据
*/
public void realTrain() {
// 初始化
classMap = new HashMap();
classP = new HashMap();
actCount.set(0);
classWordMap = new HashMap>();
// classMap.put("C000007", "汽车");
classMap.put("C000008", "财经");
classMap.put("C000010", "IT");
classMap.put("C000013", "健康");
classMap.put("C000014", "体育");
classMap.put("C000016", "旅游");
classMap.put("C000020", "教育");
classMap.put("C000022", "招聘");
classMap.put("C000023", "文化");
classMap.put("C000024", "军事");
// 计算各个类别的样本数
Set keySet = classMap.keySet();
// 所有词汇的集合,是为了计算每个单词在多少篇文章中出现,用于后面计算df
final Set allWords = new HashSet();
// 存放每个类别的文件词汇内容
final Map> classContentMap = new ConcurrentHashMap>();
for (String classKey : keySet) {
Participle participle = new IkParticiple();
Map wordMap = new HashMap();
File f = new File(trainPath + File.separator + classKey);
File[] files = f.listFiles(new FileFilter() {
@Override
public boolean accept(File pathname) {
if (pathname.getName().endsWith(".txt")) {
return true;
}
return false;
}
});
// 存储每个类别的文件词汇向量
List fileContent = new ArrayList();
if (files != null) {
for (File txt : files) {
String content = readtxt(txt.getAbsolutePath());
// 分词
String[] word_arr = participle.participle(content, false);
fileContent.add(word_arr);
// 统计每个词出现的个数
for (String word : word_arr) {
if (wordMap.containsKey(word)) {
Double wordCount = wordMap.get(word);
wordMap.put(word, wordCount + 1);
} else {
wordMap.put(word, 1.0);
}
}
}
}
// 每个类别对应的词典和频数
classWordMap.put(classKey, wordMap);
// 每个类别的文章数目
classP.put(classKey, files.length);
actCount.addAndGet(files.length);
classContentMap.put(classKey, fileContent);
}
// 把训练好的训练器对象序列化到本地 (空间换时间)
FileOutputStream fos;
try {
fos = new FileOutputStream(SERIALIZABLE_PATH);
ObjectOutputStream oos = new ObjectOutputStream(fos);
oos.writeObject(this);
} catch (Exception e) {
e.printStackTrace();
}
}
/**
* 分类
*
* @param text
* @return 返回各个类别的概率大小
*/
public Map classify(String text) {
// 分词,并且去重
String[] text_words = participle.participle(text, false);
Map frequencyOfType = new HashMap();
Set keySet = classMap.keySet();
for (String classKey : keySet) {
double typeOfThis = 1.0;
Map wordMap = classWordMap.get(classKey);
for (String word : text_words) {
Double wordCount = wordMap.get(word);
int articleCount = classP.get(classKey);
/*
* Double wordidf = idfMap.get(word); if(wordidf==null){
* wordidf=0.001; }else{ wordidf = Math.log(actCount / wordidf); }
*/
// 假如这个词在类别下的所有文章中木有,那么给定个极小的值 不影响计算
double term_frequency = (wordCount == null) ? ((double) 1 / (articleCount + 1))
: (wordCount / articleCount);
// 文本在类别的概率 在这里按照特征向量独立统计,即概率=词汇1/文章数 * 词汇2/文章数 。。。
// 当double无限小的时候会归为0,为了避免 *10
typeOfThis = typeOfThis * term_frequency * 10;
typeOfThis = ((typeOfThis == 0.0) ? Double.MIN_VALUE
: typeOfThis);
// System.out.println(typeOfThis+" : "+term_frequency+" :
// "+actCount);
}
typeOfThis = ((typeOfThis == 1.0) ? 0.0 : typeOfThis);
// 此类别文章出现的概率
double classOfAll = classP.get(classKey) / actCount.doubleValue();
// 根据贝叶斯公式 $(A|B)=S(B|A)*S(A)/S(B),由于$(B)是常数,在这里不做计算,不影响分类结果
frequencyOfType.put(classKey, typeOfThis * classOfAll);
}
return frequencyOfType;
}
public void pringAll() {
Set>> classWordEntry = classWordMap
.entrySet();
for (Entry> ent : classWordEntry) {
System.out.println("类别: " + ent.getKey());
Map wordMap = ent.getValue();
Set> wordMapSet = wordMap.entrySet();
for (Entry wordEnt : wordMapSet) {
System.out.println(wordEnt.getKey() + ":" + wordEnt.getValue());
}
}
}
public Map getClassMap() {
return classMap;
}
public void setClassMap(Map classMap) {
this.classMap = classMap;
}
}
在试验过程中,发觉某篇文章的分类不太准,某篇IT文章分到招聘类别下了,在仔细对比了训练数据后,发觉这是由于招聘类别每篇文章下面都带有“搜狗”的标志,而待分类的这篇IT文章里面充斥这搜狗这类词汇,结果招聘类下的概率比较大。由此想到,在除了做常规的贝叶斯计算时,需要把不同文本中出现次数多的词汇权重降低甚至删除(好比关键词搜索中的tf-idf),通俗点讲就是,在所有训练文本中某词汇(如的,地,得)出现的次数越多,这个词越不重要,比如IT文章中“软件”和“应用”这两个词汇,“应用”应该是很多文章类别下都有的,反而不太重要,但是“软件”这个词汇大多只出现在IT文章里,出现在大量文章的概率并不大。 我这里原本打算计算每个词的idf,然后给定一个阀值来判断是否需要纳入计算,但是由于词汇太多,计算量较大(等待结果时间较长),所以暂时注释掉了。
By 阿飞哥 转载请说明
腾讯微博:http://t.qq.com/duyunfeiRoom
新浪微博:http://weibo.com/u/1766094735
分享到:
2012-09-25 15:15
浏览 12264
分类:互联网
评论
3 楼
njthnet
2016-06-07
Participle 和 IkParticiple 这2个类找不到,能给个提示吗?
2 楼
u010402518
2015-10-15
分类还是可行的,如果学习的在多一点那就会更准了。
给大家一个调用的例子
Train train = Train.getInstance();
// 训练,训练好模型之后序列化到磁盘就不用再次训练了
//train.realTrain();
Map resultMap = train.classify("胡润研究院今日发布《胡润百富榜》,61岁的王健林及其家族以2200亿财富超过马云,第二次成为中国首富,财富比去年增长52%。大陆十亿美金富豪人数首度超越美国,达596位。51岁的马云及其家族以1450亿元退居第二,财富比去年减少3%");
train.pringAll();
1 楼
u010402518
2015-10-15
文章写的不错,思路很清晰,终于找到了一篇可以用的文章 。