多项式朴素贝叶斯文本分类 java

今天忙活了一整天,从编写朴素贝叶斯算法、保存模型,到用测试数据评估模型,都走了一遍。

下面具体来看看代码。

我用的数据先用 ansj 分词,再去掉停用词,最后的文本结构如下(每行格式为「(标签,以空格分隔的词)」,标签取 1 或 -1):


(-1,技术 特别 特别 神奇 师傅 环境 特别 优雅 下次 再来)
(-1,技师 按摩服务 专业 舒服 太爽啦)
(-1,美女服务 明晚 光顾)
(-1,服务态度)
(-1,谢谢 好吃 嘎嘎嘎嘎)
(-1,服务态度 菜品)
(-1,人漂亮 下次 点赞)
(-1,技术 服务态度)

(1,好吗 好好先生 是从)
(1,期待 坐等 开业)
(1,包房 装修 不错 很大 妹子 热情 主动 这家 夜总会 价格 能接受 不算 有机会)
(1,垃圾 经理 黄祥 老板 弟弟 主管 卖逼 王红 有一 管理方式 不怎么样 垃圾 都是 老板 找人 刷上去 一段时间 饮料 原因 想象 黑暗)
(1,帮忙 白条 急用 帮个忙)
(1,提供 练习室 不在 舞蹈 先用 镜子)


首先是模型训练及保存:

package com.meituan.model.learn;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectOutputStream;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import com.meituan.model.util.Config;
import com.meituan.model.util.NaiveBaysianBean;

/**
 * Collects the word counts of a two-class multinomial naive Bayes text model
 * from a labelled training file and serializes them.
 *
 * <p>
 * Each training line is parsed as {@code (label,w1 w2 ...)}: parentheses are
 * stripped, the line is split on commas, the first field is the label and the
 * second field is split on single spaces into features. Lines without a
 * second field are ignored.
 */
public class Learn {
	private static String trainDataPath = Config
			.getString("model.train.data.path");
	private static String modelSavePath = Config.getString("model.save.path");
	/** Smoothing value passed to the constructor; it is stored but not read while counting. */
	private double alpha = 1.0;
	/**
	 * Number of feature tokens seen per label, plus the grand total under the
	 * key {@code "total"}.
	 * NOTE(review): these are token counts, not document counts, so any class
	 * prior derived from this map is weighted by token volume - confirm this
	 * is intended.
	 */
	public Map<String, Integer> totalMap = new HashMap<String, Integer>();
	/** Occurrences of each feature in lines labelled {@code 1}. */
	public Map<String, Integer> positiveMap = new HashMap<String, Integer>();
	/** Occurrences of each feature in lines labelled {@code -1}. */
	public Map<String, Integer> negativeMap = new HashMap<String, Integer>();
	/** Distinct features seen in lines labelled {@code 1} or {@code -1}. */
	public Set<String> set = new HashSet<String>();

	public Learn() {

	}

	public Learn(double alpha) {
		this.alpha = alpha;
	}

	/**
	 * Adds {@code delta} to the counter stored under {@code key}, starting
	 * from zero when the key is absent.
	 */
	private static void addCount(Map<String, Integer> counts, String key,
			int delta) {
		Integer current = counts.get(key);
		counts.put(key, current == null ? delta : current + delta);
	}

	/**
	 * Reads the training file and accumulates the counts into
	 * {@link #positiveMap}, {@link #negativeMap}, {@link #totalMap} and
	 * {@link #set}. Calling it again adds to the existing counts.
	 *
	 * @throws IOException
	 *             if the training file cannot be opened or read
	 */
	public void trainNBByMU() throws IOException {
		BufferedReader reader = new BufferedReader(new InputStreamReader(
				new FileInputStream(trainDataPath)));
		try {
			for (String line = reader.readLine(); line != null; line = reader
					.readLine()) {
				String[] parts = line.replaceAll("\\(", "")
						.replaceAll("\\)", "").split("\\,");
				if (parts.length <= 1) {
					// No feature field on this line; nothing to count.
					continue;
				}

				String label = parts[0].trim();
				String[] features = parts[1].split(" ");

				// Only the two known labels contribute to the per-class word
				// counts and to the vocabulary.
				Map<String, Integer> classCounts = null;
				if ("1".equals(label)) {
					classCounts = positiveMap;
				} else if ("-1".equals(label)) {
					classCounts = negativeMap;
				}
				if (classCounts != null) {
					for (String feature : features) {
						set.add(feature);
						addCount(classCounts, feature, 1);
					}
				}

				// Every line with a feature field, whatever its label, adds
				// its token count to the per-label and overall totals.
				addCount(totalMap, label, features.length);
				addCount(totalMap, "total", features.length);
			}
		} finally {
			// Release the file handle even when reading fails half-way.
			reader.close();
		}
	}

	/**
	 * Trains on the configured data file and writes the resulting
	 * {@code NaiveBaysianBean} to the configured model path.
	 */
	public static void main(String[] args) throws IOException {
		Learn learn = new Learn();
		learn.trainNBByMU();
		NaiveBaysianBean nv = new NaiveBaysianBean();
		nv.setDim(learn.set.size());
		nv.setTotalMap(learn.totalMap);
		nv.setNegativeMap(learn.negativeMap);
		nv.setPositiveMap(learn.positiveMap);

		FileOutputStream fileOut = new FileOutputStream(modelSavePath);
		try {
			ObjectOutputStream out = new ObjectOutputStream(fileOut);
			out.writeObject(nv);
			// Closing the object stream flushes its buffer into the file.
			out.close();
		} finally {
			// Guarantees the file handle is released if serialization fails.
			fileOut.close();
		}
		System.out.println("sucess");

	}
}



模型加载以及预测:


package com.meituan.model.learn;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectInputStream;
import java.io.OutputStreamWriter;
import java.util.Map;

import com.meituan.model.util.Config;
import com.meituan.model.util.NaiveBaysianBean;

/**
 * Loads the serialized naive Bayes model and classifies space-separated text
 * as {@code "1"} or {@code "-1"}.
 */
public class Prediction {
	private static String testDataPath  = Config.getString("model.test.data.path");
	private static String f1measurePath=Config.getString("model.f1.measure");
	/** Model read from disk at class load; stays {@code null} if loading failed. */
	private static NaiveBaysianBean nvBaysianBean;

	static {
		FileInputStream fileIn = null;
		try {
			fileIn = new FileInputStream(Config.getString("model.save.path"));
			ObjectInputStream in = new ObjectInputStream(fileIn);
			nvBaysianBean = (NaiveBaysianBean) in.readObject();
		} catch (Exception e) {
			// Best effort: callers may still pass their own model to
			// prediction(), so a failed load is reported but not fatal here.
			e.printStackTrace();
		} finally {
			if (fileIn != null) {
				try {
					fileIn.close();
				} catch (IOException e) {
					e.printStackTrace();
				}
			}
		}
	}

	/** Returns the count stored under {@code key}, or 0 when the key is absent. */
	private static int countOf(Map<String, Integer> counts, String key) {
		Integer value = counts.get(key);
		return value == null ? 0 : value;
	}

	/** Sums all counts held in the map. */
	private static int sumOf(Map<String, Integer> counts) {
		int sum = 0;
		for (Integer value : counts.values()) {
			sum += value;
		}
		return sum;
	}

	/**
	 * Classifies one text with additively smoothed multinomial naive Bayes,
	 * comparing the two classes in log space.
	 *
	 * @param text
	 *            tokens separated by single spaces
	 * @param alpha
	 *            additive smoothing constant applied to every word count
	 * @param model
	 *            counts to score against
	 * @return {@code "1"} when the positive score is strictly greater,
	 *         otherwise {@code "-1"}
	 */
	public static String prediction(String text, double alpha,
			NaiveBaysianBean model) {
		String[] tokens = text.split(" ");
		Map<String, Integer> totalMap = model.getTotalMap();
		Map<String, Integer> positiveMap = model.getPositiveMap();
		Map<String, Integer> negativeMap = model.getNegativeMap();

		// Class priors. A label that is missing from the model counts as 0
		// (log prior of -Infinity) instead of failing on a null unboxing.
		int total = countOf(totalMap, "total");
		double pTotalP = Math.log(countOf(totalMap, "1") * 1.0 / total);
		double nTotalP = Math.log(countOf(totalMap, "-1") * 1.0 / total);

		// The smoothed denominators do not depend on the token, so compute
		// them once instead of once per token.
		double smoothing = alpha * model.getDim();
		double pDenominator = sumOf(positiveMap) + smoothing;
		double nDenominator = sumOf(negativeMap) + smoothing;

		double p = 0.0;
		double n = 0.0;
		for (String token : tokens) {
			p += Math.log((countOf(positiveMap, token) + alpha) / pDenominator);
			n += Math.log((countOf(negativeMap, token) + alpha) / nDenominator);
		}

		return (pTotalP + p) > (nTotalP + n) ? "1" : "-1";
	}

	/**
	 * Classifies every line of the test file with the loaded model and writes
	 * one {@code label,prediction} line per input line that has a text field.
	 *
	 * @param alpha
	 *            smoothing constant forwarded to
	 *            {@link #prediction(String, double, NaiveBaysianBean)}
	 * @throws IOException
	 *             if the test file cannot be read or the result file cannot be
	 *             written
	 */
	public  static void  writeF1Measure(double alpha) throws IOException{
		BufferedReader reader = new BufferedReader(new InputStreamReader(
				new FileInputStream(testDataPath)));
		try {
			BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(
					new FileOutputStream(f1measurePath)));
			try {
				for (String line = reader.readLine(); line != null; line = reader
						.readLine()) {
					String[] parts = line.replaceAll("\\(", "")
							.replaceAll("\\)", "").split("\\,");
					if (parts.length > 1) {
						String label = parts[0].trim();
						writer.write(label + ","
								+ prediction(parts[1], alpha, nvBaysianBean)
								+ "\n");
					}
				}
			} finally {
				// Flushes and releases the output even when a line fails.
				writer.close();
			}
		} finally {
			reader.close();
		}
	}



	public static void main(String[] args) throws IOException {
		System.out.println(nvBaysianBean.positiveMap.get("微信"));
		writeF1Measure(0.6);
		System.out.println(prediction("微信 校花 过来", 2, nvBaysianBean));

	}

}




测试数据的混淆矩阵情况:

package com.meituan.model.learn;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;


import com.meituan.model.util.Config;

/**
 * Confusion-matrix counts for a binary classifier together with the
 * precision, recall and F-measure derived from them.
 *
 * <p>
 * The label {@code 1} is treated as the positive class and {@code -1} as the
 * negative class. The derived scores are computed once in the constructors;
 * the setters only overwrite individual fields and do not recompute anything.
 */
public class F1Measure {

	private float precision; // precision = TP / (TP + FP)
	private float recall; // recall = TP / (TP + FN)
	private float fmeasure; // harmonic mean of precision and recall

	// Retrieved (predicted positive)
	private int TP;// true positives
	private int FP;// false positives

	// Not retrieved (predicted negative)
	private int FN;// false negatives
	private int TN;// true negatives

	/** Outcome tallies ("TP", "FP", "FN", "TN") collected while reading a result file. */
	private   Map<String, Integer> map = new HashMap<String, Integer>();

	/**
	 * Builds the measure from explicit confusion-matrix counts.
	 *
	 * @param TP
	 *            Retrieved -> true positives
	 * @param FP
	 *            Retrieved -> false positives
	 * @param FN
	 *            Not Retrieved -> false negatives
	 * @param TN
	 *            Not Retrieved -> true negatives
	 */
	public F1Measure(int TP, int FP, int FN, int TN) {

		this.TP = TP;
		this.FP = FP;
		this.FN = FN;
		this.TN = TN;
		computeScores();

	}

	/** Increments the counter stored under {@code elment}, starting at 1. */
	private  void updateMap(Map<String, Integer> map, String elment) {
		Integer current = map.get(elment);
		map.put(elment, current == null ? 1 : current + 1);
	}

	/** Returns the tally recorded for an outcome, or 0 if it never occurred. */
	private int tally(String outcome) {
		Integer value = map.get(outcome);
		return value == null ? 0 : value;
	}

	/** Divides as float, yielding 0 instead of NaN when the denominator is 0. */
	private static float ratio(int numerator, int denominator) {
		return denominator == 0 ? 0f : numerator * 1.0f / denominator;
	}

	/** Derives precision, recall and F-measure from the four counts. */
	private void computeScores() {
		this.precision = ratio(this.TP, this.TP + this.FP);
		this.recall = ratio(this.TP, this.TP + this.FN);
		float sum = this.precision + this.recall;
		// With no true positives both scores are 0; report 0 rather than NaN.
		this.fmeasure = sum == 0f ? 0f : 2 * this.precision * this.recall
				/ sum;
	}

	/**
	 * Builds the measure from a result file whose lines have the form
	 * {@code actual,predicted} with both values being {@code 1} or {@code -1}.
	 * Lines with fewer than two fields or with other values are ignored. An
	 * I/O failure is printed and the counts read so far are used.
	 *
	 * @param path
	 *            location of the result file
	 */
	public F1Measure(String path) {
		BufferedReader br = null;
		try {
			br = new BufferedReader(new FileReader(new File( path)));
			for (String text = br.readLine(); text != null; text = br
					.readLine()) {
				String[] texts = text.split("\\,");
				if (texts.length < 2) {
					// Blank or malformed line: no prediction to score.
					continue;
				}
				String actual = texts[0];
				String predicted = texts[1];
				if ("1".equals(actual) && "1".equals(predicted)) {
					updateMap(map, "TP");
				} else if ("1".equals(actual) && "-1".equals(predicted)) {
					updateMap(map, "FN");
				} else if ("-1".equals(actual) && "1".equals(predicted)) {
					updateMap(map, "FP");
				} else if ("-1".equals(actual) && "-1".equals(predicted)) {
					updateMap(map, "TN");
				}
			}
		} catch (IOException e) {
			e.printStackTrace();
		} finally {
			// br is still null when the file could not be opened.
			if (br != null) {
				try {
					br.close();
				} catch (IOException e) {
					e.printStackTrace();
				}
			}
			System.out.println("map is :"+map);
		}
		// An outcome that never occurred has no map entry; count it as 0.
		this.TP = tally("TP");
		this.FP = tally("FP");
		this.FN = tally("FN");
		this.TN = tally("TN");
		computeScores();
	}

	/**
	 * Renders the confusion matrix (rows: actual class, columns: predicted
	 * class, positive class first) followed by recall, precision and
	 * F-measure.
	 */
	@Override
	public String toString() {

		StringBuilder result = new StringBuilder();
		result.append("\t预测审核不通过\t\t预测审核通过\n");
		// Actual-positive row: predicted positive (TP), predicted negative (FN).
		result.append("实际审核不通过\t").append(this.TP).append("\t\t")
				.append(this.FN).append("\n");
		// Actual-negative row: predicted positive (FP), predicted negative (TN).
		result.append("实际审核通过\t").append(this.FP).append("\t\t")
				.append(this.TN).append("\n\n");

		result.append("R(召回率):  ").append(this.recall).append("\n");
		result.append("P(准确率):  ").append(this.precision).append("\n");
		result.append("f-measure:  ").append(this.fmeasure).append("\n");

		return result.toString();
	}

	public float getPrecision() {
		return precision;
	}

	public void setPrecision(float precision) {
		this.precision = precision;
	}

	public float getRecall() {
		return recall;
	}

	public void setRecall(float recall) {
		this.recall = recall;
	}

	public float getFmeasure() {
		return fmeasure;
	}

	public void setFmeasure(float fmeasure) {
		this.fmeasure = fmeasure;
	}

	public int getTP() {
		return TP;
	}

	public void setTP(int tP) {
		TP = tP;
	}

	public int getFP() {
		return FP;
	}

	public void setFP(int fP) {
		FP = fP;
	}

	public int getFN() {
		return FN;
	}

	public void setFN(int fN) {
		FN = fN;
	}

	public int getTN() {
		return TN;
	}

	public void setTN(int tN) {
		TN = tN;
	}

	public static void main(String[] args) {
		String path=Config.getString("model.f1.measure");
		// F1Measure f1 = new F1Measure(20, 30, 0, 50);
		F1Measure f1 = new F1Measure(path);
		System.out.println(f1);
	}

}


最后测试结果:

map is :{FN=1526, FP=734, TN=3336, TP=4714}

预测审核不通过 预测审核通过

实际审核不通过 4714 734

实际审核通过 1526 3336


R(召回率):  0.7554487

P(准确率):  0.8652717

f-measure:  0.80663925




  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值