多项式与伯努利朴素贝叶斯文本分类模型的 Java 实现

最近要做一个文本分类项目,并把它集成到现有系统中。原本打算用 Spark 实现,但发现 RDD 不方便存储,于是自己用 Java 写了一个实现。

原理主要参考:

http://blog.csdn.net/cxmscb/article/details/69267326

http://blog.163.com/jiayouweijiewj@126/blog/static/1712321772010102802635243/



代码用到的训练数据(每行是一篇文档,各字段以逗号分隔):

Chinese,Beijing,Chinese,yes
Chinese,Chinese,Shanghai,yes
Chinese,Macao,yes
Tokyo,Japan,Chinese,no


其中每行最后一个字段(yes / no)是类别标签。代码如下:

package com.meituan.model.learn;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import org.apache.commons.lang.ArrayUtils;
/**
 * Binary ("yes"/"no") naive Bayes text classifier supporting two event models.
 *
 * <ul>
 * <li>Multinomial: {@link #initWithMU} to train, {@link #trainNBWithMU} to classify.</li>
 * <li>Bernoulli: {@link #initWithBO} to train, {@link #trainNBWithBO} to classify.</li>
 * </ul>
 *
 * <p>The training file is read as UTF-8. Every non-blank line is one document:
 * comma-separated words followed by the class label as the last field, for example
 * {@code Chinese,Beijing,Chinese,yes}. Labels are matched case-insensitively
 * against "yes" and "no". Surrounding whitespace around each field is ignored.
 *
 * <p>Statistics accumulate in the public maps below. Initialise an instance with
 * exactly one of the two models; the maps mean different things in each model.
 */
public class Learn {
	/** Default training file used by {@link #main}. */
	public static String path = "/Users/shuubiasahi/Desktop/bayies/bayies.txt";
	/**
	 * Per-label counts plus the grand total under the key {@code "total"}.
	 * Multinomial model: number of word tokens. Bernoulli model: number of documents.
	 */
	public Map<String, Integer> totalMap = new HashMap<String, Integer>();
	/** Word statistics of "yes" documents: token counts (multinomial) or document frequencies (Bernoulli). */
	public Map<String, Integer> yesMap = new HashMap<String, Integer>();
	/** Word statistics of "no" documents: token counts (multinomial) or document frequencies (Bernoulli). */
	public Map<String, Integer> noMap = new HashMap<String, Integer>();
	/** Default additive (Laplace) smoothing parameter used by {@link #main}. */
	public static double alpha = 1.0;
	/** Vocabulary: every word seen in a "yes" or "no" training document. */
	public Set<String> set = new HashSet<String>();

	private static final String YES = "yes";
	private static final String NO = "no";
	private static final String TOTAL = "total";

	/**
	 * Trains the multinomial model: every word occurrence is counted.
	 *
	 * @param path training file (UTF-8, one comma-separated document per line, label last)
	 * @throws IOException if the file cannot be read
	 */
	public void initWithMU(String path) throws IOException {
		load(path, false);
	}

	/**
	 * Classifies {@code text} with the multinomial model loaded by {@link #initWithMU}.
	 * Despite its name, this method performs prediction, not training.
	 *
	 * <p>Score(c) = log P(c) + sum over every word occurrence w of
	 * log((count(w, c) + alpha) / (tokens(c) + alpha * |V|)).
	 *
	 * @param text  comma-separated words to classify
	 * @param alpha additive smoothing parameter
	 * @return "yes" if its score is strictly greater than the "no" score, otherwise "no"
	 * @throws IllegalStateException if no training data has been loaded
	 */
	public String trainNBWithMU(String text, double alpha) {
		String[] words = tokenize(text);
		double vocabularySize = set.size();
		int yesTokens = sum(yesMap);
		int noTokens = sum(noMap);
		double yesScore = logPrior(YES);
		double noScore = logPrior(NO);
		for (String word : words) {
			yesScore += Math.log((count(yesMap, word) + alpha) / (yesTokens + alpha * vocabularySize));
			noScore += Math.log((count(noMap, word) + alpha) / (noTokens + alpha * vocabularySize));
		}
		return yesScore > noScore ? YES : NO;
	}

	/**
	 * Trains the Bernoulli model: each distinct word is counted once per document,
	 * and {@code totalMap} counts documents.
	 *
	 * @param path training file (UTF-8, one comma-separated document per line, label last)
	 * @throws IOException if the file cannot be read
	 */
	public void initWithBO(String path) throws IOException {
		load(path, true);
	}

	/**
	 * Classifies {@code text} with the Bernoulli model loaded by {@link #initWithBO}.
	 * Despite its name, this method performs prediction, not training.
	 *
	 * <p>P(t|c) = (docs of c containing t + alpha) / (docs of c + 2 * alpha). Every
	 * vocabulary term contributes: log P(t|c) if it occurs in {@code text},
	 * otherwise log(1 - P(t|c)). Query words outside the training vocabulary are ignored.
	 * Both class scores are printed to standard output.
	 *
	 * @param text  comma-separated words to classify (duplicates are irrelevant)
	 * @param alpha additive smoothing parameter
	 * @return "yes" if its score is strictly greater than the "no" score, otherwise "no"
	 * @throws IllegalStateException if no training data has been loaded
	 */
	public String trainNBWithBO(String text, double alpha) {
		Set<String> present = new HashSet<String>(Arrays.asList(tokenize(text)));
		double yesScore = logPrior(YES) + bernoulliLogLikelihood(yesMap, count(totalMap, YES), present, alpha);
		double noScore = logPrior(NO) + bernoulliLogLikelihood(noMap, count(totalMap, NO), present, alpha);
		System.out.println("yes:" + yesScore);
		System.out.println("no :" + noScore);
		return yesScore > noScore ? YES : NO;
	}

	public static void main(String[] args) throws IOException {
		Learn learn = new Learn();
		learn.initWithBO(Learn.path);
		// Sample query, written the same way as the training documents.
		System.out.println(learn.trainNBWithBO("Chinese,Chinese,Chinese,Tokyo,Japan", Learn.alpha));
		// Reference log-posteriors (~0.005 for yes, ~0.022 for no) that the Bernoulli
		// model gives this query on the four-line sample training set; compare with the scores above.
		System.out.println(Math.log(0.005));
		System.out.print(Math.log(0.022));
	}

	/** Reads the training file and accumulates model statistics (shared by both init methods). */
	private void load(String path, boolean bernoulli) throws IOException {
		BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(path), "UTF-8"));
		try {
			for (String line = reader.readLine(); line != null; line = reader.readLine()) {
				String[] tokens = tokenize(line);
				if (tokens.length == 0) {
					continue; // blank or separator-only line: not a document
				}
				int wordCount = tokens.length - 1;
				String label = normalizeLabel(tokens[wordCount]);
				Iterable<String> words;
				if (bernoulli) {
					// Bernoulli model: presence/absence only, so each word counts once per document.
					words = new HashSet<String>(Arrays.asList(tokens).subList(0, wordCount));
				} else {
					words = Arrays.asList(tokens).subList(0, wordCount);
				}
				Map<String, Integer> wordCounts = wordCountsFor(label);
				if (wordCounts != null) {
					for (String word : words) {
						set.add(word);
						increment(wordCounts, word, 1);
					}
				}
				// Multinomial totals count tokens; Bernoulli totals count documents.
				int weight = bernoulli ? 1 : wordCount;
				increment(totalMap, label, weight);
				increment(totalMap, TOTAL, weight);
			}
		} finally {
			reader.close();
		}
	}

	/** Sums log P(t|c) or log(1 - P(t|c)) over the whole vocabulary for one class. */
	private double bernoulliLogLikelihood(Map<String, Integer> docFrequency, int classDocs,
			Set<String> present, double alpha) {
		double logLikelihood = 0.0;
		for (String term : set) {
			double p = (count(docFrequency, term) + alpha) / (classDocs + 2 * alpha);
			logLikelihood += Math.log(present.contains(term) ? p : 1.0 - p);
		}
		return logLikelihood;
	}

	/** Returns log P(label) from {@code totalMap}; fails fast when nothing has been loaded. */
	private double logPrior(String label) {
		int total = count(totalMap, TOTAL);
		if (total == 0) {
			throw new IllegalStateException("No training data loaded; call initWithMU or initWithBO first");
		}
		return Math.log((double) count(totalMap, label) / total);
	}

	/** Returns the word-statistics map for a normalised label, or null for any other label. */
	private Map<String, Integer> wordCountsFor(String label) {
		if (YES.equals(label)) {
			return yesMap;
		}
		if (NO.equals(label)) {
			return noMap;
		}
		return null;
	}

	/** Maps any case variant of "yes"/"no" to the canonical lower-case key; other labels pass through. */
	private static String normalizeLabel(String label) {
		if (YES.equalsIgnoreCase(label)) {
			return YES;
		}
		if (NO.equalsIgnoreCase(label)) {
			return NO;
		}
		return label;
	}

	/** Splits on commas and keeps the trimmed, non-empty fields in order. */
	private static String[] tokenize(String text) {
		String[] raw = text.split(",");
		String[] tokens = new String[raw.length];
		int n = 0;
		for (String field : raw) {
			String token = field.trim();
			if (token.length() > 0) {
				tokens[n++] = token;
			}
		}
		return Arrays.copyOf(tokens, n);
	}

	/** Adds {@code by} to the count stored under {@code key}, treating a missing key as 0. */
	private static void increment(Map<String, Integer> map, String key, int by) {
		Integer old = map.get(key);
		map.put(key, old == null ? by : old + by);
	}

	/** Returns the count stored under {@code key}, or 0 when absent. */
	private static int count(Map<String, Integer> map, String key) {
		Integer value = map.get(key);
		return value == null ? 0 : value;
	}

	/** Returns the sum of all values in {@code map}. */
	private static int sum(Map<String, Integer> map) {
		int total = 0;
		for (int value : map.values()) {
			total += value;
		}
		return total;
	}
}



  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值