Naive Bayes 朴素贝叶斯(文本)分类器Java实现

  • 算法原理推导
  • 伪代码
  • java实现代码
  • 测试数据

1. 算法原理推导

1.1 优缺点分析

优点:在数据较少的情况下,仍然有效,可以处理多分类问题
缺点:对于输入数据的准备方式比较敏感
适用数据类型:标称型数据
###主要思想
p1(x,y)表示数据点(x,y)属于类别1的概率;
p2(x,y)表示数据点(x,y)属于类别2的概率;
if:
p1>p2 属于1类;
else
属于2类

1.2 假设条件

  • 特征之间相互独立
    • 一个特征或者单词出现的可能性与它或其他单词相邻没有关系
  • 每个特征同等重要

1.3 文本分类实现方式

  • 基于伯努利模型(不考虑词出现次数)
  • 基于多项式模型(考虑词出现次数)

2. 伪代码

//C,类别集合,D,用于训练的文本文件集合
TrainMultiNomialNB(C,D) {
    // 单词出现多次,只算一个
    V←ExtractVocabulary(D)
    // 单词可重复计算
    N←CountTokens(D)
    for each c∈C
        // 计算类别c下的单词总数
        // N和Nc的计算方法和Introduction to Information Retrieval上的不同,个人认为
        //该书是错误的,先验概率和类条件概率的计算方法应当保持一致
        Nc←CountTokensInClass(D,c)
        prior[c]←Nc/N
        // 将类别c下的文档连接成一个大字符串
        textc←ConcatenateTextOfAllDocsInClass(D,c)
        for each t∈V
            // 计算类c下单词t的出现次数
            Tct←CountTokensOfTerm(textc,t)
        for each t∈V
            //计算P(t|c)
            condprob[t][c]←(Tct+1)/(Σt′Tt′c+|V|)
    return V,prior,condprob
}

ApplyMultiNomialNB(C,V,prior,condprob,d) {
    // 将文档d中的单词抽取出来,允许重复,如果单词是全新的,在全局单词表V中都
    // 没出现过,则忽略掉
    W←ExtractTokensFromDoc(V,d)
    for each c∈C
        score[c]←prior[c]
        for each t∈W
            // W已经按全局词表V过滤过,直接累乘类条件概率
            score[c] *= condprob[t][c]
    return max(score[c])

}
/************************************************************************/
//C,类别集合,D,用于训练的文本文件集合
TrainBernoulliNB(C, D) {
    // 单词出现多次,只算一个
    V←ExtractVocabulary(D)
    // 计算文件总数
    N←CountDocs(D)
    for each c∈C
        // 计算类别c下的文件总数
        Nc←CountDocsInClass(D,c)
        prior[c]←Nc/N
        for each t∈V
            // 计算类c下包含单词t的文件数
            Nct←CountDocsInClassContainingTerm(D,c,t)
            //计算P(t|c)
        condprob[t][c]←(Nct+1)/(Nc+2)
    return V,prior,condprob
}

ApplyBernoulliNB(C,V,prior,condprob,d) {
    // 将文档d中单词表抽取出来,如果单词是全新的,在全局单词表V中都没出现过,
    // 则舍弃
    Vd←ExtractTermsFromDoc(V,d)
    for each c∈C
        score[c]←prior[c]
        for each t∈V
            if t∈Vd
                score[c] *= condprob[t][c]
            else
                score[c] *= (1-condprob[t][c])
    return max(score[c])
}

3. java代码实现

3.1 主函数

/**
 * 
 */

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import com.utils.MathM;

/**
 * @author Home
 * 
 */
public class NBMain {

	/** Path of the training data CSV: one document per line, comma-separated tokens, last field is the class label. */
	final static String dataPath = "data.csv";

	static List<String[]> dataList = new ArrayList<String[]>(); // tokenized training documents
	static List<float[]> vectorList = new ArrayList<float[]>(); // one-hot word vectors of the training documents
	static List<String> vocabList = new ArrayList<String>(); // sorted global vocabulary
	static float[] trainCategory; // class label of each training doc, e.g. trainCategory=[0,1,0,1,0,1]
	static int numTrainDocs = 0; // number of training documents
	static int numwords = 0; // vocabulary size

	static MathM mm = new MathM();

	public static void main(String[] args) throws IOException {
		// Load the data set, build the vocabulary, then train and try the classifier.
		vocabList = loadDataSet(dataPath);
		Model model = trainBayes();
		System.out.println(Arrays.toString(model.p0Vect));
		System.out.println(Arrays.toString(model.p1Vect));
		String[] test1 = {"love","my","dalmation"};
		String[] test2 = {"stupid","garbage"};
		System.out.println(classifyNB(setofWords2Vec(vocabList, test1,test1.length),model));
		System.out.println(classifyNB(setofWords2Vec(vocabList, test2,test2.length),model));

	}

	/**
	 * Classifies a one-hot document vector with the trained model.
	 * Scores are computed in log space: log P(c) + sum of log P(w|c) over present words
	 * (the model vectors already contain log-probabilities, see trainBayes).
	 *
	 * @param vec2Classify one-hot vector over the vocabulary
	 * @param model        trained model (log-probability vectors + class-1 prior)
	 * @return 1 (abusive class) if its score is higher, otherwise 0
	 */
	public static int classifyNB(float[] vec2Classify, Model model) {
		double p1 = mm.multiply(vec2Classify, model.p1Vect) + Math.log(model.pAbusive);
		double p0 = mm.multiply(vec2Classify, model.p0Vect) + Math.log(1 - model.pAbusive);
		return p1 > p0 ? 1 : 0;
	}

	/**
	 * Reads the CSV training file, fills dataList/vectorList/trainCategory and
	 * returns the sorted vocabulary.
	 * NOTE(review): FileReader uses the platform default charset — confirm the
	 * encoding of data.csv if it ever contains non-ASCII tokens.
	 *
	 * @throws IOException if reading the file fails after it was opened
	 */
	public static List<String> loadDataSet(String dataPath) throws IOException {
		// try-with-resources closes the reader even when readLine throws;
		// the original leaked the reader on IOException (close was not in a finally).
		try (BufferedReader br = new BufferedReader(new FileReader(new File(dataPath)))) {
			String line;
			while ((line = br.readLine()) != null) {
				dataList.add(line.split(","));
			}
		} catch (FileNotFoundException e) {
			// Preserve the original best-effort behavior: log and continue with empty data.
			e.printStackTrace();
		}
		numTrainDocs = dataList.size();
		List<String> vocabList = new ArrayList<String>();
		trainCategory = new float[dataList.size()];
		int j = 0;
		for (String[] str : dataList) {
			// Every field except the last is a word; the last field is the class label.
			for (int i = 0; i < str.length - 1; i++)
				if (!vocabList.contains(str[i]))
					vocabList.add(str[i]);
			trainCategory[j] = Integer.parseInt(str[str.length - 1]);
			j++;
		}
		Collections.sort(vocabList);

		// Vectorize every training document against the final vocabulary.
		for (String[] str : dataList) {
			vectorList.add(setofWords2Vec(vocabList, str, str.length - 1));
		}
		numwords = vocabList.size();
		return vocabList;
	}

	/**
	 * Converts the first n tokens of postingDoc into a one-hot vector over vocabList.
	 * Out-of-vocabulary words are ignored — the original wrote temp[-1] and threw
	 * ArrayIndexOutOfBoundsException for any unseen word.
	 *
	 * @param vocabList  global vocabulary (defines the vector dimensions)
	 * @param postingDoc tokens of one document
	 * @param n          number of leading tokens to use (excludes a trailing label field)
	 * @return one-hot encoding of the document
	 */
	public static float[] setofWords2Vec(List<String> vocabList, String[] postingDoc, int n) {
		float[] temp = new float[vocabList.size()];
		for (int i = 0; i < n; i++) {
			int index = vocabList.indexOf(postingDoc[i]);
			if (index >= 0) { // skip words not present in the vocabulary
				temp[index] = 1.0f;
			}
		}
		return temp;
	}

	/*
	 * 1. Compute the prior of the abusive class p(c1); p(c0) = 1 - p(c1).
	 * 2. Compute p(wi|c1) and p(wi|c0) for every vocabulary word.
	 * Reads the static dataList / vectorList / trainCategory filled by loadDataSet.
	 */

	/**
	 * Trains the naive Bayes model from the loaded training vectors.
	 *
	 * @return model holding log P(w|c0), log P(w|c1) and the class-1 prior
	 */
	public static Model trainBayes() {

		float pAbusive = (float) (mm.sum(trainCategory) / numTrainDocs);
		float[] p0Num = new float[numwords];
		float[] p1Num = new float[numwords];
		// Laplace smoothing: start every word count at 1 and each denominator at 2
		// so that unseen words do not zero out the product of probabilities.
		Arrays.fill(p0Num, 1);
		Arrays.fill(p1Num, 1);
		float p0Denom = 2.0f;
		float p1Denom = 2.0f;
		for (int i = 0; i < numTrainDocs; i++) {
			float[] temp = vectorList.get(i);
			if (trainCategory[i] == 1) {
				// mm.dot is element-wise addition: accumulate per-word counts.
				p1Num = mm.dot(p1Num, temp);
				p1Denom += mm.sum(temp);
			} else {
				p0Num = mm.dot(p0Num, temp);
				p0Denom += mm.sum(temp);
			}
		}
		// Guard against underflow: products of many small p(w|c) vanish in float,
		// so store log-probabilities instead (ln(a*b) = ln(a) + ln(b)).
		float[] p1Vect = mm.fVect(p1Num, p1Denom);
		float[] p0Vect = mm.fVect(p0Num, p0Denom);
		return new Model(p0Vect, p1Vect, pAbusive);
	}
}

3.2 模型构造方法

/**
 *模型构造方法 
 */
package com.loadData;

/**
 * @author Home
 *
 */
/**
 * Immutable-by-convention container for a trained naive Bayes model.
 */
public class Model {
	// log-probability vector log P(w|class=0), one entry per vocabulary word
	float[] p0Vect;
	// log-probability vector log P(w|class=1)
	float[] p1Vect;
	// prior probability P(class=1), i.e. fraction of abusive training docs
	float pAbusive;

	/**
	 * @param p0Vect   log P(w|class=0) per vocabulary word
	 * @param p1Vect   log P(w|class=1) per vocabulary word
	 * @param pAbusive prior P(class=1)
	 */
	public Model(float[] p0Vect, float[] p1Vect, float pAbusive) {
		this.p0Vect = p0Vect;
		this.p1Vect = p1Vect;
		this.pAbusive = pAbusive;
	}

	public float[] getP0Vect() {
		return p0Vect;
	}

	public void setP0Vect(float[] p0Vect) {
		this.p0Vect = p0Vect;
	}

	public float[] getP1Vect() {
		return p1Vect;
	}

	public void setP1Vect(float[] p1Vect) {
		this.p1Vect = p1Vect;
	}

	/** Properly named accessor for the class-1 prior. */
	public float getPAbusive() {
		return pAbusive;
	}

	public void setpAbusive(float pAbusive) {
		this.pAbusive = pAbusive;
	}

	/**
	 * @deprecated misnamed accessor kept for existing callers;
	 *             use {@link #getPAbusive()} instead.
	 */
	@Deprecated
	public float getmodel() {
		return pAbusive;
	}
}

3.3 工具类

package com.utils;
/*
 *自定义方法
 * */

/**
 * Small vector-math helper used by the naive Bayes trainer/classifier.
 */
public class MathM {

	/** Returns the sum of all elements of R (0 for an empty array). */
	public float sum(float[] R) {
		float total = 0f;
		for (int i = 0; i < R.length; i++) {
			total += R[i];
		}
		return total;
	}

	/**
	 * Element-wise ADDITION of two equal-length vectors.
	 * NOTE(review): despite the name "dot", this returns A[i] + B[i] — the
	 * trainer relies on this to accumulate per-word counts, so the behavior
	 * is intentional even though the name is misleading.
	 */
	public float[] dot(float A[], float B[]) {
		float[] out = new float[A.length];
		for (int i = 0; i < out.length; i++) {
			out[i] = A[i] + B[i];
		}
		return out;
	}

	/** Converts counts to log-probabilities: returns log(A[i] / pDenom) per element. */
	public float[] fVect(float[] A, float pDenom) {
		float[] logs = new float[A.length];
		for (int i = 0; i < A.length; i++) {
			logs[i] = (float) Math.log(A[i] / pDenom);
		}
		return logs;
	}

	/** Inner (dot) product of A and B, accumulated in double precision. */
	public double multiply(float[] A, float[] B) {
		double acc = 0;
		for (int i = 0; i < A.length; i++) {
			acc += A[i] * B[i];
		}
		return acc;
	}
}

4. 数据集

data.csv

my,dog,has,flea,problem,help,please,0
maybe,not,take,him,to,dog,park,stupid,1
my,dalmation,is,so,cute,I,love,him,0
stop,posting,stupid,worthless,garbage,1
me,licks,ate,my,steak,how,to,stop,him,0
quit,buying,worthless,dog,food,stupid,1

参考文献:《机器学习实战》
在这里插入图片描述

发布了26 篇原创文章 · 获赞 19 · 访问量 5万+
展开阅读全文

没有更多推荐了,返回首页

©️2019 CSDN 皮肤主题: 大白 设计师: CSDN官方博客

分享到微信朋友圈

×

扫一扫,手机浏览