昨天实现了一个基于贝叶斯定理的文本分类器。贝叶斯分类假设各个特征属性(在文本中就是词汇)对待分类项的影响相互独立,原理比较简单。在中文分类系统中,分类的准确性与分词系统的好坏有很大关系,这段代码也是在试验不同分词系统时顺手写的。
试验数据用的是搜狗(Sogou)实验室的文本分类样本,一共分为9个类别,每个类别的文件夹下大约有2000篇文章。由于文本数据量确实较大,需要想办法把每次训练的结果保存起来,以便下次直接使用。我这里采用Java对象序列化的方式把训练结果保存到硬盘。
训练代码如下:
- /**
- * 训练器
- *
- * @author duyf
- *
- */
- class Train implements Serializable {
-
- /**
- *
- */
- private static final long serialVersionUID = 1L;
-
- public final static String SERIALIZABLE_PATH = "D:\\workspace\\Test\\SogouC.mini\\Sample\\Train.ser";
- // 训练集的位置
- private String trainPath = "D:\\workspace\\Test\\SogouC.mini\\Sample";
-
- // 类别序号对应的实际名称
- private Map<String, String> classMap = new HashMap<String, String>();
-
- // 类别对应的txt文本数
- private Map<String, Integer> classP = new ConcurrentHashMap<String, Integer>();
-
- // 所有文本数
- private AtomicInteger actCount = new AtomicInteger(0);
-
-
-
- // 每个类别对应的词典和频数
- private Map<String, Map<String, Double>> classWordMap = new ConcurrentHashMap<String, Map<String, Double>>();
-
- // 分词器
- private transient Participle participle;
-
- private static Train trainInstance = new Train();
-
- public static Train getInstance() {
- trainInstance = new Train();
-
- // 读取序列化在硬盘的本类对象
- FileInputStream fis;
- try {
- File f = new File(SERIALIZABLE_PATH);
- if (f.length() != 0) {
- fis = new FileInputStream(SERIALIZABLE_PATH);
- ObjectInputStream oos = new ObjectInputStream(fis);
- trainInstance = (Train) oos.readObject();
- trainInstance.participle = new IkParticiple();
- } else {
- trainInstance = new Train();
- }
- } catch (Exception e) {
- e.printStackTrace();
- }
-
- return trainInstance;
- }
-
- private Train() {
- this.participle = new IkParticiple();
- }
-
- public String readtxt(String path) {
- BufferedReader br = null;
- StringBuilder str = null;
- try {
- br = new BufferedReader(new FileReader(path));
-
- str = new StringBuilder();
-
- String r = br.readLine();
-
- while (r != null) {
- str.append(r);
- r = br.readLine();
-
- }
-
- return str.toString();
- } catch (IOException ex) {
- ex.printStackTrace();
- } finally {
- if (br != null) {
- try {
- br.close();
- } catch (IOException e) {
- e.printStackTrace();
- }
- }
- str = null;
- br = null;
- }
-
- return "";
- }
-
- /**
- * 训练数据
- */
- public void realTrain() {
- // 初始化
- classMap = new HashMap<String, String>();
- classP = new HashMap<String, Integer>();
- actCount.set(0);
- classWordMap = new HashMap<String, Map<String, Double>>();
-
- // classMap.put("C000007", "汽车");
- classMap.put("C000008", "财经");
- classMap.put("C000010", "IT");
- classMap.put("C000013", "健康");
- classMap.put("C000014", "体育");
- classMap.put("C000016", "旅游");
- classMap.put("C000020", "教育");
- classMap.put("C000022", "招聘");
- classMap.put("C000023", "文化");
- classMap.put("C000024", "军事");
-
- // 计算各个类别的样本数
- Set<String> keySet = classMap.keySet();
-
- // 所有词汇的集合,是为了计算每个单词在多少篇文章中出现,用于后面计算df
- final Set<String> allWords = new HashSet<String>();
-
- // 存放每个类别的文件词汇内容
- final Map<String, List<String[]>> classContentMap = new ConcurrentHashMap<String, List<String[]>>();
-
- for (String classKey : keySet) {
-
- Participle participle = new IkParticiple();
- Map<String, Double> wordMap = new HashMap<String, Double>();
- File f = new File(trainPath + File.separator + classKey);
- File[] files = f.listFiles(new FileFilter() {
-
- @Override
- public boolean accept(File pathname) {
- if (pathname.getName().endsWith(".txt")) {
- return true;
- }
- return false;
- }
-
- });
-
- // 存储每个类别的文件词汇向量
- List<String[]> fileContent = new ArrayList<String[]>();
- if (files != null) {
- for (File txt : files) {
- String content = readtxt(txt.getAbsolutePath());
- // 分词
- String[] word_arr = participle.participle(content, false);
- fileContent.add(word_arr);
- // 统计每个词出现的个数
- for (String word : word_arr) {
- if (wordMap.containsKey(word)) {
- Double wordCount = wordMap.get(word);
- wordMap.put(word, wordCount + 1);
- } else {
- wordMap.put(word, 1.0);
- }
-
- }
- }
- }
-
- // 每个类别对应的词典和频数
- classWordMap.put(classKey, wordMap);
-
- // 每个类别的文章数目
- classP.put(classKey, files.length);
- actCount.addAndGet(files.length);
- classContentMap.put(classKey, fileContent);
-
- }
-
-
-
-
-
- // 把训练好的训练器对象序列化到本地 (空间换时间)
- FileOutputStream fos;
- try {
- fos = new FileOutputStream(SERIALIZABLE_PATH);
- ObjectOutputStream oos = new ObjectOutputStream(fos);
- oos.writeObject(this);
- } catch (Exception e) {
- e.printStackTrace();
- }
-
- }
-
- /**
- * 分类
- *
- * @param text
- * @return 返回各个类别的概率大小
- */
- public Map<String, Double> classify(String text) {
- // 分词,并且去重
- String[] text_words = participle.participle(text, false);
-
- Map<String, Double> frequencyOfType = new HashMap<String, Double>();
- Set<String> keySet = classMap.keySet();
- for (String classKey : keySet) {
- double typeOfThis = 1.0;
- Map<String, Double> wordMap = classWordMap.get(classKey);
- for (String word : text_words) {
- Double wordCount = wordMap.get(word);
- int articleCount = classP.get(classKey);
-
- /*
- * Double wordidf = idfMap.get(word); if(wordidf==null){
- * wordidf=0.001; }else{ wordidf = Math.log(actCount / wordidf); }
- */
-
- // 假如这个词在类别下的所有文章中木有,那么给定个极小的值 不影响计算
- double term_frequency = (wordCount == null) ? ((double) 1 / (articleCount + 1))
- : (wordCount / articleCount);
-
- // 文本在类别的概率 在这里按照特征向量独立统计,即概率=词汇1/文章数 * 词汇2/文章数 。。。
- // 当double无限小的时候会归为0,为了避免 *10
-
- typeOfThis = typeOfThis * term_frequency * 10;
- typeOfThis = ((typeOfThis == 0.0) ? Double.MIN_VALUE
- : typeOfThis);
- // System.out.println(typeOfThis+" : "+term_frequency+" :
- // "+actCount);
- }
-
- typeOfThis = ((typeOfThis == 1.0) ? 0.0 : typeOfThis);
-
- // 此类别文章出现的概率
- double classOfAll = classP.get(classKey) / actCount.doubleValue();
-
- // 根据贝叶斯公式 $(A|B)=S(B|A)*S(A)/S(B),由于$(B)是常数,在这里不做计算,不影响分类结果
- frequencyOfType.put(classKey, typeOfThis * classOfAll);
- }
-
- return frequencyOfType;
- }
-
- public void pringAll() {
- Set<Entry<String, Map<String, Double>>> classWordEntry = classWordMap
- .entrySet();
- for (Entry<String, Map<String, Double>> ent : classWordEntry) {
- System.out.println("类别: " + ent.getKey());
- Map<String, Double> wordMap = ent.getValue();
- Set<Entry<String, Double>> wordMapSet = wordMap.entrySet();
- for (Entry<String, Double> wordEnt : wordMapSet) {
- System.out.println(wordEnt.getKey() + ":" + wordEnt.getValue());
- }
- }
- }
-
- public Map<String, String> getClassMap() {
- return classMap;
- }
-
- public void setClassMap(Map<String, String> classMap) {
- this.classMap = classMap;
- }
-
- }
在试验过程中,我发现有些文章的分类不太准,比如某篇IT文章被分到了招聘类别下。仔细对比训练数据后发现,这是由于招聘类别下的每篇文章都带有“搜狗”的标志,而这篇待分类的IT文章里恰好充斥着“搜狗”这类词汇,结果招聘类别的概率反而比较大。由此想到,除了做常规的贝叶斯计算之外,还需要降低在大量文本中普遍出现的词汇的权重,甚至直接删除这些词汇(好比关键词搜索中的tf-idf)。通俗点讲就是:某个词汇(如“的”“地”“得”)出现在越多的训练文本中,这个词就越不重要。比如IT文章中的“软件”和“应用”这两个词汇,“应用”在很多类别的文章里都有,反而不太重要;而“软件”大多只出现在IT文章里,出现在其他大量文章中的概率并不大。我原本打算计算每个词的idf(即 idf(w) = log(N / df(w)),其中 N 为训练文本总数,df(w) 为包含词 w 的文本数),然后给定一个阈值来判断该词是否需要纳入计算。但是由于词汇太多、计算量较大(等待结果的时间较长),所以暂时把这部分代码注释掉了。
- /**
- * 训练器
- *
- * @author duyf
- *
- */
- class Train implements Serializable {
- /**
- *
- */
- private static final long serialVersionUID = 1L;
- public final static String SERIALIZABLE_PATH = "D:\\workspace\\Test\\SogouC.mini\\Sample\\Train.ser";
- // 训练集的位置
- private String trainPath = "D:\\workspace\\Test\\SogouC.mini\\Sample";
- // 类别序号对应的实际名称
- private Map<String, String> classMap = new HashMap<String, String>();
- // 类别对应的txt文本数
- private Map<String, Integer> classP = new ConcurrentHashMap<String, Integer>();
- // 所有文本数
- private AtomicInteger actCount = new AtomicInteger(0);
- // 每个类别对应的词典和频数
- private Map<String, Map<String, Double>> classWordMap = new ConcurrentHashMap<String, Map<String, Double>>();
- // 分词器
- private transient Participle participle;
- private static Train trainInstance = new Train();
- public static Train getInstance() {
- trainInstance = new Train();
- // 读取序列化在硬盘的本类对象
- FileInputStream fis;
- try {
- File f = new File(SERIALIZABLE_PATH);
- if (f.length() != 0) {
- fis = new FileInputStream(SERIALIZABLE_PATH);
- ObjectInputStream oos = new ObjectInputStream(fis);
- trainInstance = (Train) oos.readObject();
- trainInstance.participle = new IkParticiple();
- } else {
- trainInstance = new Train();
- }
- } catch (Exception e) {
- e.printStackTrace();
- }
- return trainInstance;
- }
- private Train() {
- this.participle = new IkParticiple();
- }
- public String readtxt(String path) {
- BufferedReader br = null;
- StringBuilder str = null;
- try {
- br = new BufferedReader(new FileReader(path));
- str = new StringBuilder();
- String r = br.readLine();
- while (r != null) {
- str.append(r);
- r = br.readLine();
- }
- return str.toString();
- } catch (IOException ex) {
- ex.printStackTrace();
- } finally {
- if (br != null) {
- try {
- br.close();
- } catch (IOException e) {
- e.printStackTrace();
- }
- }
- str = null;
- br = null;
- }
- return "";
- }
- /**
- * 训练数据
- */
- public void realTrain() {
- // 初始化
- classMap = new HashMap<String, String>();
- classP = new HashMap<String, Integer>();
- actCount.set(0);
- classWordMap = new HashMap<String, Map<String, Double>>();
- // classMap.put("C000007", "汽车");
- classMap.put("C000008", "财经");
- classMap.put("C000010", "IT");
- classMap.put("C000013", "健康");
- classMap.put("C000014", "体育");
- classMap.put("C000016", "旅游");
- classMap.put("C000020", "教育");
- classMap.put("C000022", "招聘");
- classMap.put("C000023", "文化");
- classMap.put("C000024", "军事");
- // 计算各个类别的样本数
- Set<String> keySet = classMap.keySet();
- // 所有词汇的集合,是为了计算每个单词在多少篇文章中出现,用于后面计算df
- final Set<String> allWords = new HashSet<String>();
- // 存放每个类别的文件词汇内容
- final Map<String, List<String[]>> classContentMap = new ConcurrentHashMap<String, List<String[]>>();
- for (String classKey : keySet) {
- Participle participle = new IkParticiple();
- Map<String, Double> wordMap = new HashMap<String, Double>();
- File f = new File(trainPath + File.separator + classKey);
- File[] files = f.listFiles(new FileFilter() {
- @Override
- public boolean accept(File pathname) {
- if (pathname.getName().endsWith(".txt")) {
- return true;
- }
- return false;
- }
- });
- // 存储每个类别的文件词汇向量
- List<String[]> fileContent = new ArrayList<String[]>();
- if (files != null) {
- for (File txt : files) {
- String content = readtxt(txt.getAbsolutePath());
- // 分词
- String[] word_arr = participle.participle(content, false);
- fileContent.add(word_arr);
- // 统计每个词出现的个数
- for (String word : word_arr) {
- if (wordMap.containsKey(word)) {
- Double wordCount = wordMap.get(word);
- wordMap.put(word, wordCount + 1);
- } else {
- wordMap.put(word, 1.0);
- }
- }
- }
- }
- // 每个类别对应的词典和频数
- classWordMap.put(classKey, wordMap);
- // 每个类别的文章数目
- classP.put(classKey, files.length);
- actCount.addAndGet(files.length);
- classContentMap.put(classKey, fileContent);
- }
- // 把训练好的训练器对象序列化到本地 (空间换时间)
- FileOutputStream fos;
- try {
- fos = new FileOutputStream(SERIALIZABLE_PATH);
- ObjectOutputStream oos = new ObjectOutputStream(fos);
- oos.writeObject(this);
- } catch (Exception e) {
- e.printStackTrace();
- }
- }
- /**
- * 分类
- *
- * @param text
- * @return 返回各个类别的概率大小
- */
- public Map<String, Double> classify(String text) {
- // 分词,并且去重
- String[] text_words = participle.participle(text, false);
- Map<String, Double> frequencyOfType = new HashMap<String, Double>();
- Set<String> keySet = classMap.keySet();
- for (String classKey : keySet) {
- double typeOfThis = 1.0;
- Map<String, Double> wordMap = classWordMap.get(classKey);
- for (String word : text_words) {
- Double wordCount = wordMap.get(word);
- int articleCount = classP.get(classKey);
- /*
- * Double wordidf = idfMap.get(word); if(wordidf==null){
- * wordidf=0.001; }else{ wordidf = Math.log(actCount / wordidf); }
- */
- // 假如这个词在类别下的所有文章中木有,那么给定个极小的值 不影响计算
- double term_frequency = (wordCount == null) ? ((double) 1 / (articleCount + 1))
- : (wordCount / articleCount);
- // 文本在类别的概率 在这里按照特征向量独立统计,即概率=词汇1/文章数 * 词汇2/文章数 。。。
- // 当double无限小的时候会归为0,为了避免 *10
- typeOfThis = typeOfThis * term_frequency * 10;
- typeOfThis = ((typeOfThis == 0.0) ? Double.MIN_VALUE
- : typeOfThis);
- // System.out.println(typeOfThis+" : "+term_frequency+" :
- // "+actCount);
- }
- typeOfThis = ((typeOfThis == 1.0) ? 0.0 : typeOfThis);
- // 此类别文章出现的概率
- double classOfAll = classP.get(classKey) / actCount.doubleValue();
- // 根据贝叶斯公式 $(A|B)=S(B|A)*S(A)/S(B),由于$(B)是常数,在这里不做计算,不影响分类结果
- frequencyOfType.put(classKey, typeOfThis * classOfAll);
- }
- return frequencyOfType;
- }
- public void pringAll() {
- Set<Entry<String, Map<String, Double>>> classWordEntry = classWordMap
- .entrySet();
- for (Entry<String, Map<String, Double>> ent : classWordEntry) {
- System.out.println("类别: " + ent.getKey());
- Map<String, Double> wordMap = ent.getValue();
- Set<Entry<String, Double>> wordMapSet = wordMap.entrySet();
- for (Entry<String, Double> wordEnt : wordMapSet) {
- System.out.println(wordEnt.getKey() + ":" + wordEnt.getValue());
- }
- }
- }
- public Map<String, String> getClassMap() {
- return classMap;
- }
- public void setClassMap(Map<String, String> classMap) {
- this.classMap = classMap;
- }
- }