最近要做一个文本分类项目,并把它集成到现有系统中。原本打算用 Spark 实现,但发现训练结果(RDD)不方便存储,于是自己写了一个朴素贝叶斯分类器,包含多项式模型和伯努利模型两种实现。
原理主要参考:
http://blog.csdn.net/cxmscb/article/details/69267326
http://blog.163.com/jiayouweijiewj@126/blog/static/1712321772010102802635243/
代码用到的数据:
Chinese,Beijing,Chinese,yes
Chinese,Chinese,Shanghai,yes
Chinese,Macao,yes
Tokyo,Japan,Chinese,no
每行最后一个字段(yes/no)是类别标签,前面的字段是文档中的词。代码如下:
package com.meituan.model.learn;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang.ArrayUtils;
/**
 * Naive Bayes text classifier for two labels, {@code "yes"} and {@code "no"}.
 *
 * <p>Training data is read from a file with one document per line: comma-separated terms
 * followed by the label as the last field, e.g. {@code Chinese,Beijing,Chinese,yes}. Blank
 * lines are skipped; terms are trimmed and empty terms are ignored. Labels are matched
 * case-insensitively and stored as lower-case {@code "yes"} / {@code "no"}.
 *
 * <p>Two event models are implemented:
 * <ul>
 *   <li>multinomial ({@link #initWithMU} + {@link #trainNBWithMU}): counts every term
 *       occurrence; the class prior is the share of training tokens in that class;</li>
 *   <li>Bernoulli ({@link #initWithBO} + {@link #trainNBWithBO}): counts the documents that
 *       contain a term; the class prior is the share of training documents in that class, and
 *       vocabulary terms absent from the query contribute {@code 1 - P(term|class)}.</li>
 * </ul>
 *
 * <p>Both init methods accumulate into the same public maps and never clear them, so an
 * instance should be trained with one model only, once. The state is plain
 * {@code HashMap}/{@code HashSet} with no synchronization.
 */
public class Learn {
    // Example query used by main(): Chinese,Chinese,Chinese,Tokyo,Japan

    /** Default training file used by {@link #main}; a machine-specific path. */
    public static String path = "/Users/shuubiasahi/Desktop/bayies/bayies.txt";

    /**
     * Per-label totals plus the key {@code "total"}: token counts after {@link #initWithMU},
     * document counts after {@link #initWithBO}.
     */
    public Map<String, Integer> totalMap = new HashMap<String, Integer>();

    /** Term counts for "yes" documents: occurrences (MU) or documents containing the term (BO). */
    public Map<String, Integer> yesMap = new HashMap<String, Integer>();

    /** Term counts for "no" documents: occurrences (MU) or documents containing the term (BO). */
    public Map<String, Integer> noMap = new HashMap<String, Integer>();

    /** Default additive (Laplace) smoothing constant. */
    public static double alpha = 1.0;

    /** Vocabulary: every term seen in a "yes" or "no" training document. */
    public Set<String> set = new HashSet<String>();

    private static final String YES = "yes";
    private static final String NO = "no";
    private static final String TOTAL = "total";

    /**
     * Loads training data for the multinomial model.
     *
     * @param path training file (UTF-8), one {@code term,term,...,label} document per line
     * @throws IOException if the file cannot be opened or read
     */
    public void initWithMU(String path) throws IOException {
        load(path, false);
    }

    /**
     * Classifies a comma-separated document with the multinomial model built by
     * {@link #initWithMU}. Despite its name, this method does not train anything.
     *
     * @param text  comma-separated query terms
     * @param alpha additive smoothing constant (1.0 = Laplace smoothing)
     * @return {@code "yes"} if its log-posterior is strictly greater, otherwise {@code "no"}
     * @throws IllegalStateException if no training data has been loaded
     */
    public String trainNBWithMU(String text, double alpha) {
        String[] fields = text.split(",");
        String[] terms = cleanTerms(fields, fields.length);
        int yesTokens = sum(yesMap);
        int noTokens = sum(noMap);
        int vocabulary = set.size();
        double yesScore = logPrior(YES);
        double noScore = logPrior(NO);
        for (String term : terms) {
            // Every term (seen or not) gets alpha pseudo-counts, so no class is zeroed out.
            yesScore += Math.log((count(yesMap, term) + alpha) / (yesTokens + alpha * vocabulary));
            noScore += Math.log((count(noMap, term) + alpha) / (noTokens + alpha * vocabulary));
        }
        return yesScore > noScore ? YES : NO;
    }

    /**
     * Loads training data for the Bernoulli model: each distinct term counts once per document,
     * and {@code totalMap} holds document counts per label and in total.
     *
     * @param path training file (UTF-8), one {@code term,term,...,label} document per line
     * @throws IOException if the file cannot be opened or read
     */
    public void initWithBO(String path) throws IOException {
        load(path, true);
    }

    /**
     * Classifies a comma-separated document with the Bernoulli model built by
     * {@link #initWithBO}. Despite its name, this method does not train anything.
     *
     * <p>For every vocabulary term, {@code P(t|c) = (docs of c containing t + alpha) /
     * (docs of c + 2 * alpha)}. Present terms contribute {@code log P(t|c)} and absent terms
     * contribute {@code log(1 - P(t|c))}. Query terms outside the vocabulary are ignored. The
     * two log-posteriors are printed to standard output before the decision is returned.
     *
     * @param text  comma-separated query terms (repeats are irrelevant in this model)
     * @param alpha additive smoothing constant (1.0 = Laplace smoothing)
     * @return {@code "yes"} if its log-posterior is strictly greater, otherwise {@code "no"}
     * @throws IllegalStateException if no training data has been loaded
     */
    public String trainNBWithBO(String text, double alpha) {
        String[] fields = text.split(",");
        Set<String> present = new HashSet<String>(Arrays.asList(cleanTerms(fields, fields.length)));
        int yesDocs = count(totalMap, YES);
        int noDocs = count(totalMap, NO);
        double yesScore = logPrior(YES);
        double noScore = logPrior(NO);
        for (String term : set) {
            boolean inQuery = present.contains(term);
            yesScore += bernoulliLog(count(yesMap, term), yesDocs, alpha, inQuery);
            noScore += bernoulliLog(count(noMap, term), noDocs, alpha, inQuery);
        }
        System.out.println("yes:" + yesScore);
        System.out.println("no :" + noScore);
        return yesScore > noScore ? YES : NO;
    }

    /** Trains the Bernoulli model on {@link #path} and classifies the example query. */
    public static void main(String[] args) throws IOException {
        Learn learn = new Learn();
        learn.initWithBO(Learn.path);
        System.out.println(learn.trainNBWithBO("Chinese,Chinese,Chinese,Tokyo,Japan", Learn.alpha));
        // Hand-computed Bernoulli posteriors for this query on the sample data
        // (about 0.005 for yes, 0.022 for no), printed as logs for comparison with the above.
        System.out.println(Math.log(0.005));
        System.out.print(Math.log(0.022));
    }

    /**
     * Reads the training file and accumulates counts for one event model.
     *
     * @param bernoulli {@code true}: count each distinct term once per document and count
     *                  documents in {@code totalMap}; {@code false}: count every occurrence
     *                  and count tokens in {@code totalMap}
     */
    private void load(String path, boolean bernoulli) throws IOException {
        BufferedReader reader = openReader(path);
        try {
            String line;
            while ((line = reader.readLine()) != null) {
                String[] fields = line.split(",");
                if (fields.length == 0 || line.trim().length() == 0) {
                    continue; // blank or all-comma line: no label, not a document
                }
                String label = canonicalLabel(fields[fields.length - 1]);
                String[] terms = cleanTerms(fields, fields.length - 1);
                if (bernoulli) {
                    // Presence only: collapse repeated terms within the document.
                    Set<String> distinct = new HashSet<String>(Arrays.asList(terms));
                    terms = distinct.toArray(new String[distinct.size()]);
                }
                Map<String, Integer> termCounts = termCountsFor(label);
                if (termCounts != null) {
                    for (String term : terms) {
                        set.add(term);
                        increment(termCounts, term, 1);
                    }
                }
                int weight = bernoulli ? 1 : terms.length;
                increment(totalMap, label, weight);
                increment(totalMap, TOTAL, weight);
            }
        } finally {
            reader.close();
        }
    }

    /** Opens {@code path} as a UTF-8 buffered reader, closing the stream if that fails. */
    private static BufferedReader openReader(String path) throws IOException {
        FileInputStream in = new FileInputStream(path);
        try {
            return new BufferedReader(new InputStreamReader(in, "UTF-8"));
        } catch (IOException e) {
            in.close();
            throw e;
        }
    }

    /** Trims the label and maps any casing of yes/no to the lower-case keys used everywhere. */
    private static String canonicalLabel(String raw) {
        String label = raw.trim();
        if (YES.equalsIgnoreCase(label)) {
            return YES;
        }
        if (NO.equalsIgnoreCase(label)) {
            return NO;
        }
        return label;
    }

    /** Returns the per-term count map for a canonical label, or null for other labels. */
    private Map<String, Integer> termCountsFor(String label) {
        if (YES.equals(label)) {
            return yesMap;
        }
        if (NO.equals(label)) {
            return noMap;
        }
        return null;
    }

    /** Returns the trimmed, non-empty entries of {@code fields[0..end)}. */
    private static String[] cleanTerms(String[] fields, int end) {
        String[] out = new String[end];
        int n = 0;
        for (int i = 0; i < end; i++) {
            String term = fields[i].trim();
            if (term.length() > 0) {
                out[n++] = term;
            }
        }
        return Arrays.copyOf(out, n);
    }

    /** Adds {@code delta} to {@code counts[key]}, treating a missing key as 0. */
    private static void increment(Map<String, Integer> counts, String key, int delta) {
        Integer old = counts.get(key);
        counts.put(key, old == null ? delta : old + delta);
    }

    /** Returns {@code counts[key]}, or 0 when absent. */
    private static int count(Map<String, Integer> counts, String key) {
        Integer value = counts.get(key);
        return value == null ? 0 : value;
    }

    /** Sums all values of a count map. */
    private static int sum(Map<String, Integer> counts) {
        int total = 0;
        for (Integer value : counts.values()) {
            total += value;
        }
        return total;
    }

    /**
     * Log prior of a label from {@code totalMap}; -Infinity if the label never occurred.
     *
     * @throws IllegalStateException if nothing has been loaded
     */
    private double logPrior(String label) {
        int total = count(totalMap, TOTAL);
        if (total <= 0) {
            throw new IllegalStateException("no training data loaded; call initWithMU/initWithBO first");
        }
        return Math.log((double) count(totalMap, label) / total);
    }

    /** Smoothed Bernoulli log-likelihood of one vocabulary term being present or absent. */
    private static double bernoulliLog(int docsWithTerm, int docsInClass, double alpha, boolean present) {
        double p = (docsWithTerm + alpha) / (docsInClass + 2 * alpha);
        return Math.log(present ? p : 1.0 - p);
    }
}