一、贝叶斯理论
学过概率的同学一定都知道贝叶斯定理:
这个在250多年前发明的算法,在信息领域内有着无与伦比的地位。贝叶斯分类是一系列分类算法的总称,这类算法均以贝叶斯定理为基础,故统称为贝叶斯分类。朴素贝叶斯算法(Naive Bayesian) 是其中应用最为广泛的分类算法之一。
朴素贝叶斯分类器基于一个简单的假定:给定目标值时属性之间相互条件独立。
通过以上定理和“朴素”的假定,我们知道:
P( Category | Document) = P ( Document | Category ) * P( Category) / P(Document)
我们现在有一个数据集,它由两类数据组成,数据分布如下图所示:
我们现在用 p1(x,y) 表示数据点 (x,y) 属于类别 1(图中用圆点表示的类别)的概率,用 p2(x,y) 表示数据点 (x,y) 属于类别 2(图中三角形表示的类别)的概率,那么对于一个新数据点 (x,y),可以用下面的规则来判断它的类别:
- 如果 p1(x,y) > p2(x,y) ,那么类别为1
- 如果 p2(x,y) > p1(x,y) ,那么类别为2
也就是说,我们会选择高概率对应的类别。这就是贝叶斯决策理论的核心思想,即选择具有最高概率的决策。
在文档分类中,整个文档(如一封电子邮件)是实例,而电子邮件中的某些元素则构成特征。我们可以观察文档中出现的词,并把每个词作为一个特征,而每个词的出现或者不出现作为该特征的值,这样得到的特征数目就会跟词汇表中的词的数目一样多。
我们假设特征之间 相互独立 。所谓 独立(independence) 指的是统计意义上的独立,即一个特征或者单词出现的可能性与它和其他单词相邻没有关系,比如说,“我们”中的“我”和“们”出现的概率与这两个字相邻没有任何关系。这个假设正是朴素贝叶斯分类器中 朴素(naive) 一词的含义。朴素贝叶斯分类器中的另一个假设是,每个特征同等重要。
Note: 朴素贝叶斯分类器通常有两种实现方式: 一种基于伯努利模型实现,一种基于多项式模型实现。这里采用前一种实现方式。该实现方式中并不考虑词在文档中出现的次数,只考虑出不出现,因此在这个意义上相当于假设词是等权重的。
三、朴素贝叶斯 场景
机器学习的一个重要应用就是文档的自动分类。
在文档分类中,整个文档(如一封电子邮件)是实例,而电子邮件中的某些元素则构成特征。我们可以观察文档中出现的词,并把每个词作为一个特征,而每个词的出现或者不出现作为该特征的值,这样得到的特征数目就会跟词汇表中的词的数目一样多。
朴素贝叶斯是上面介绍的贝叶斯分类器的一个扩展,是用于文档分类的常用算法。下面我们会进行一些朴素贝叶斯分类的实践项目。
朴素贝叶斯 原理
朴素贝叶斯 工作原理
提取所有文档中的词条并进行去重
获取文档的所有类别
计算每个类别中的文档数目
对每篇训练文档:
对每个类别:
如果词条出现在文档中-->增加该词条的计数值(for循环或者矩阵相加)
增加所有词条的计数值(此类别下词条总数)
对每个类别:
对每个词条:
将该词条的数目除以总词条数目得到的条件概率(P(词条|类别))
返回该文档属于每个类别的条件概率(P(类别|文档的所有词条))
朴素贝叶斯 开发流程
收集数据: 可以使用任何方法。
准备数据: 需要数值型或者布尔型数据。
分析数据: 有大量特征时,绘制特征作用不大,此时使用直方图效果更好。
训练算法: 计算不同的独立特征的条件概率。
测试算法: 计算错误率。
使用算法: 一个常见的朴素贝叶斯应用是文档分类。可以在任意的分类场景中使用朴素贝叶斯分类器,不一定非要是文本。
朴素贝叶斯 算法特点
优点: 在数据较少的情况下仍然有效,可以处理多类别问题。
缺点: 对于输入数据的准备方式较为敏感。
适用数据类型: 标称型数据。
项目案例(文章得好评度)
有以下数据集,前面为每行数据格式为:类别\t词语\s词语\s词语.......,如下格式。
附上数据集地址:https://pan.baidu.com/s/1WETILqblEil4-5Mv6jHxSA
利用Java实现单机版朴素贝叶斯分类器,如下代码:
package com.hadoop; import java.io.*; import java.nio.file.Files; import java.nio.file.Paths; import java.util.*; public class NBClassifier { private static String trainingDataFilePath = "F:/train_data/training-1000.txt"; private static String modelFilePath = "F:/parameters-1000.model"; private static String testDataFilePath = "F:/train_data/test-1000.txt"; private static String outputFilePath = "F:/预测结果.result"; public static String[] extractFeatures(String sentence) { /******* 添加句子按空格分割取特征字符串数组的代码 *******/ String[] split = sentence.split(" "); return split; } //从文本文件中训练出模型 public static void train() throws Exception { HashMap<String, Integer> parameters = new HashMap<String, Integer>(); /******* 添加“类别\t计数”和“类别-特征\t计数”统计代码 *******/ BufferedReader br = new BufferedReader(new FileReader(trainingDataFilePath)); String sentence = null; while ( null != (sentence = br.readLine()) ) { String[] content = sentence.split("\t| "); parameters.put(content[0], parameters.getOrDefault(content[0], 0)+1); for (int i = 1; i < content.length; i++) { parameters.put(content[0]+"-"+content[i], parameters.getOrDefault(content[0]+"-"+content[i], 0)+1); } } br.close(); saveModel(parameters); } private static void saveModel(HashMap<String, Integer> parameters) { Iterator<String> keyIter = parameters.keySet().iterator(); BufferedWriter bw = null; try { bw = new BufferedWriter(new FileWriter(modelFilePath)); } catch (IOException e) { e.printStackTrace(); } while (keyIter.hasNext()) { String key = keyIter.next(); int value = parameters.get(key); try { bw.append(key + "\t" + value + "\r\n"); } catch (IOException e) { e.printStackTrace(); } } try { bw.flush(); bw.close(); } catch (IOException e) { e.printStackTrace(); } } private static HashMap<String, Integer> parameters = null; private static Set<String> V = null; private static double Nd; private static double sizeOfV; //加载训练的模型(key value) public static void loadModel() { V = new HashSet<String>(); parameters = new HashMap<String, Integer>(); try { List<String> parameterData = Files.readAllLines(Paths.get(modelFilePath)); for (int i = 0; i < parameterData.size(); i++) { String parameter = parameterData.get(i); String key = parameter.substring(0, parameter.indexOf("\t")); Integer value = Integer.parseInt(parameter.substring(parameter.indexOf("\t") + 1)); parameters.put(key, value); if (key.contains("-")) { String feature = key.substring(key.indexOf("-") + 1); V.add(feature); } if (!key.contains("-")) { Nd += value; } } } catch (IOException e) { e.printStackTrace(); } } //实现关键 public static String predict(String sentence) { String[] labels = {"好评", "差评"}; String[] features = extractFeatures(sentence); double maxProb = Double.NEGATIVE_INFINITY; String prediction = null; /******* 添加预测模型的代码实现 *******/ Map<String, Double> prior = totalcatagory(parameters); double good = Math.log( prior.get(labels[0]) / prior.get("Nd") ) + likehood(parameters, labels[0], features); double bad = Math.log( prior.get(labels[1]) / prior.get("Nd") ) + likehood(parameters, labels[1], features); if (good>bad) { prediction = labels[0]; maxProb = good; }else if (good<bad) { prediction = labels[1]; maxProb = bad; }else { prediction = "无法预测"; } return prediction; } //计算模型中类别的总数 public static Map<String, Double> totalcatagory(Map<String,Integer> parameters) { double good = 0.0; double bad = 0.0; double restotal = 0.0; HashMap<String, Double> hashMap = new HashMap<>(); for (Map.Entry<String, Integer> map : parameters.entrySet()) { String key = map.getKey(); Integer value = map.getValue(); if (key.contains("好评-")) { good += value; }else if (key.contains("差评-")) { bad += value; } } restotal = good + bad; hashMap.put("好评", good); hashMap.put("差评", bad); hashMap.put("Nd", restotal); return hashMap; } //似然概率的计算 public static double likehood(Map<String, Integer> parameters,String catagory,String[] features) { double p=0.0; Map<String, Double> totalcatagory = totalcatagory(parameters); //分母平滑处理 Double V = totalcatagory.get(catagory) + 1; for (String word : features) { //分子平滑处理 Integer Nc = parameters.getOrDefault(catagory+"-"+word, 0) + 1; p += Math.log( Nc/V ); } return p; } public static void predictAll() { double accuracy = 0.; int amount = 0; BufferedWriter bw = null; try { bw = new BufferedWriter(new FileWriter(outputFilePath)); } catch (IOException e) { e.printStackTrace(); } try { List<String> testData = Files.readAllLines(Paths.get(testDataFilePath)); for (String instance : testData) { String gold = instance.substring(0, instance.indexOf("\t")); String sentence = instance.substring(instance.indexOf("\t") + 1); String prediction = predict(sentence); System.out.println("Gold='" + gold + "'\tPrediction='" + prediction + "'"); // 将prediction按行输出到文件中的结果 bw.append(prediction+"\r\n"); if (gold.equals(prediction)) { accuracy += 1.; } amount += 1; } bw.close(); } catch (Exception e) { e.printStackTrace(); } //打印好评率 System.out.println("Accuracy = " + accuracy / amount); } public static void main(String[] args) throws Exception { //training train(); //test loadModel(); //预测文档 predictAll(); } }
训练模型如下:
在以上的训练模型中,预测测试集文本(预测每行数据中,类别后的语句为好评或者差评)的结果为:
总结:
以上就是用Java实现的朴素贝叶斯分类器。如果文本数据很大,那么就可以考虑使用Hadoop的MapReduce来实现训练模型,然后核心算法不变。后面我会使用MapReduce实现贝叶斯分类器。