文本情感分析之朴素贝叶斯

一、贝叶斯理论

学过概率的同学一定都知道贝叶斯定理:

这个在250多年前发明的算法,在信息领域内有着无与伦比的地位。贝叶斯分类是一系列分类算法的总称,这类算法均以贝叶斯定理为基础,故统称为贝叶斯分类。朴素贝叶斯算法(Naive Bayesian) 是其中应用最为广泛的分类算法之一。

朴素贝叶斯分类器基于一个简单的假定:给定目标值时属性之间相互条件独立。

通过以上定理和“朴素”的假定,我们知道:

P( Category | Document) = P ( Document | Category ) * P( Category) / P(Document)

我们现在有一个数据集,它由两类数据组成,数据分布如下图所示:

我们现在用 p1(x,y) 表示数据点 (x,y) 属于类别 1(图中用圆点表示的类别)的概率,用 p2(x,y) 表示数据点 (x,y) 属于类别 2(图中三角形表示的类别)的概率,那么对于一个新数据点 (x,y),可以用下面的规则来判断它的类别:

  • 如果 p1(x,y) > p2(x,y) ,那么类别为1
  • 如果 p2(x,y) > p1(x,y) ,那么类别为2

也就是说,我们会选择高概率对应的类别。这就是贝叶斯决策理论的核心思想,即选择具有最高概率的决策。

 

在文档分类中,整个文档(如一封电子邮件)是实例,而电子邮件中的某些元素则构成特征。我们可以观察文档中出现的词,并把每个词作为一个特征,而每个词的出现或者不出现作为该特征的值,这样得到的特征数目就会跟词汇表中的词的数目一样多。

我们假设特征之间 相互独立 。所谓 独立(independence) 指的是统计意义上的独立,即一个特征或者单词出现的可能性与它和其他单词相邻没有关系,比如说,我们中的出现的概率与这两个字相邻没有任何关系。这个假设正是朴素贝叶斯分类器中 朴素(naive) 一词的含义。朴素贝叶斯分类器中的另一个假设是,每个特征同等重要

Note: 朴素贝叶斯分类器通常有两种实现方式: 一种基于伯努利模型实现,一种基于多项式模型实现。这里采用前一种实现方式。该实现方式中并不考虑词在文档中出现的次数,只考虑出不出现,因此在这个意义上相当于假设词是等权重的。

三、朴素贝叶斯 场景

机器学习的一个重要应用就是文档的自动分类。

在文档分类中,整个文档(如一封电子邮件)是实例,而电子邮件中的某些元素则构成特征。我们可以观察文档中出现的词,并把每个词作为一个特征,而每个词的出现或者不出现作为该特征的值,这样得到的特征数目就会跟词汇表中的词的数目一样多。

朴素贝叶斯是上面介绍的贝叶斯分类器的一个扩展,是用于文档分类的常用算法。下面我们会进行一些朴素贝叶斯分类的实践项目。

朴素贝叶斯 原理

朴素贝叶斯 工作原理

提取所有文档中的词条并进行去重

获取文档的所有类别

计算每个类别中的文档数目

对每篇训练文档:

    对每个类别:

        如果词条出现在文档中-->增加该词条的计数值(for循环或者矩阵相加)

        增加所有词条的计数值(此类别下词条总数)

对每个类别:

    对每个词条:

        将该词条的数目除以总词条数目得到的条件概率(P(词条|类别)

返回该文档属于每个类别的条件概率(P(类别|文档的所有词条)

朴素贝叶斯 开发流程

收集数据: 可以使用任何方法。

准备数据: 需要数值型或者布尔型数据。

分析数据: 有大量特征时,绘制特征作用不大,此时使用直方图效果更好。

训练算法: 计算不同的独立特征的条件概率。

测试算法: 计算错误率。

使用算法: 一个常见的朴素贝叶斯应用是文档分类。可以在任意的分类场景中使用朴素贝叶斯分类器,不一定非要是文本。

朴素贝叶斯 算法特点

优点: 在数据较少的情况下仍然有效,可以处理多类别问题。

缺点: 对于输入数据的准备方式较为敏感。

适用数据类型: 标称型数据。

 

项目案例(文章得好评度)

有以下数据集,前面为每行数据格式为:类别\t词语\s词语\s词语.......,如下格式。

附上数据集地址:https://pan.baidu.com/s/1WETILqblEil4-5Mv6jHxSA

利用Java实现单机版朴素贝叶斯分类器,如下代码:

package com.hadoop;

import java.io.*;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;

public class NBClassifier {
	private static String trainingDataFilePath = "F:/train_data/training-1000.txt";
    private static String modelFilePath = "F:/parameters-1000.model";
    private static String testDataFilePath = "F:/train_data/test-1000.txt";
    private static String outputFilePath = "F:/预测结果.result";

    public static String[] extractFeatures(String sentence) {
        /******* 添加句子按空格分割取特征字符串数组的代码 *******/
    	String[] split = sentence.split(" ");
        return split;
    }
    //从文本文件中训练出模型
    public static void train() throws Exception {
        HashMap<String, Integer> parameters = new HashMap<String, Integer>();

        /******* 添加“类别\t计数”和“类别-特征\t计数”统计代码 *******/
        BufferedReader br = new BufferedReader(new FileReader(trainingDataFilePath));
        String sentence = null;
        while ( null != (sentence = br.readLine()) ) {
			String[] content = sentence.split("\t| ");
			parameters.put(content[0], parameters.getOrDefault(content[0], 0)+1);
			for (int i = 1; i < content.length; i++) {
				parameters.put(content[0]+"-"+content[i],         parameters.getOrDefault(content[0]+"-"+content[i], 0)+1);
			}
		}
        br.close();
        saveModel(parameters);
    }

    private static void saveModel(HashMap<String, Integer> parameters) {
        Iterator<String> keyIter = parameters.keySet().iterator();
        BufferedWriter bw = null;

        try {
            bw = new BufferedWriter(new FileWriter(modelFilePath));
        } catch (IOException e) {
            e.printStackTrace();
        }

        while (keyIter.hasNext()) {
            String key = keyIter.next();
            int value = parameters.get(key);

            try {
                bw.append(key + "\t" + value + "\r\n");
            } catch (IOException e) {
                e.printStackTrace();
            }
        }

        try {
            bw.flush();
            bw.close();
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    private static HashMap<String, Integer> parameters = null;
    private static Set<String> V = null;
    private static double Nd;
    private static double sizeOfV;
    //加载训练的模型(key value)
    public static void loadModel() {
        V = new HashSet<String>();
        parameters = new HashMap<String, Integer>();

        try {
            List<String> parameterData = Files.readAllLines(Paths.get(modelFilePath));

            for (int i = 0; i < parameterData.size(); i++) {
                String parameter = parameterData.get(i);
                String key = parameter.substring(0, parameter.indexOf("\t"));
                Integer value = Integer.parseInt(parameter.substring(parameter.indexOf("\t") + 1));

                parameters.put(key, value);

                if (key.contains("-")) {
                    String feature = key.substring(key.indexOf("-") + 1);

                    V.add(feature);
                }

                if (!key.contains("-")) {
                    Nd += value;
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
    //实现关键
    public static String predict(String sentence) {
        String[] labels = {"好评", "差评"};
        String[] features = extractFeatures(sentence);

        double maxProb = Double.NEGATIVE_INFINITY;
        String prediction = null;

        /******* 添加预测模型的代码实现 *******/
        Map<String, Double> prior = totalcatagory(parameters);
        double good = Math.log( prior.get(labels[0]) / prior.get("Nd") ) + likehood(parameters, labels[0], features);
        double bad = Math.log( prior.get(labels[1]) / prior.get("Nd") ) + likehood(parameters, labels[1], features);
        if (good>bad) {
        	prediction = labels[0];
        	maxProb = good;
		}else if (good<bad) {
			prediction = labels[1];
			maxProb = bad;
		}else {
			prediction = "无法预测";
		}
        return prediction;
    }
    //计算模型中类别的总数
    public static Map<String, Double> totalcatagory(Map<String,Integer> parameters) {
    	double good = 0.0;
    	double bad = 0.0;
    	double restotal = 0.0;
    	HashMap<String, Double> hashMap = new HashMap<>();
    	for (Map.Entry<String, Integer> map : parameters.entrySet()) {
			String key = map.getKey();
			Integer value = map.getValue();
			if (key.contains("好评-")) {
				good += value;
			}else if (key.contains("差评-")) {
				bad += value;
			}
		}
    	
    	restotal = good + bad;
    	hashMap.put("好评", good);
    	hashMap.put("差评", bad);
    	hashMap.put("Nd", restotal);
    	return hashMap;
    }
    //似然概率的计算
    public static double likehood(Map<String, Integer> parameters,String catagory,String[] features) {
    	double p=0.0;
    	Map<String, Double> totalcatagory = totalcatagory(parameters);
        //分母平滑处理
    	Double V = totalcatagory.get(catagory) + 1;
    	for (String word : features) {
            //分子平滑处理
    		Integer Nc = parameters.getOrDefault(catagory+"-"+word, 0) + 1;
			p += Math.log(   Nc/V  );
		}
    	return p;
    }

    public static void predictAll() {
    	double accuracy = 0.;
        int amount = 0;
        BufferedWriter bw = null;
        try {
            bw = new BufferedWriter(new FileWriter(outputFilePath));
        } catch (IOException e) {
            e.printStackTrace();
        }
        try {
            List<String> testData = Files.readAllLines(Paths.get(testDataFilePath));
            for (String instance : testData) {
            	
            	String gold = instance.substring(0, instance.indexOf("\t"));
                String sentence = instance.substring(instance.indexOf("\t") + 1);
                String prediction = predict(sentence); 

                System.out.println("Gold='" + gold + "'\tPrediction='" + prediction + "'");
                // 将prediction按行输出到文件中的结果
                bw.append(prediction+"\r\n");
                if (gold.equals(prediction)) {
                    accuracy += 1.;
                }
                amount += 1;
            }
            bw.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
        //打印好评率
        System.out.println("Accuracy = " + accuracy / amount);
    }
    public static void main(String[] args) throws Exception {
    	
    	//training
    	train();
    	//test
        loadModel();
        //预测文档
        predictAll();
        
    }
}

训练模型如下:

在以上的训练模型中,预测测试集文本(预测每行数据中,类别后的语句为好评或者差评)的结果为:

总结:

以上就是用Java实现的朴素贝叶斯分类器。如果文本数据很大,那么就可以考虑使用Hadoop的MapReduce来实现训练模型,然后核心算法不变。后面我会使用MapReduce实现贝叶斯分类器。

 

  • 6
    点赞
  • 44
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值