1、Bayes.java
package com.bayes;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
/**
 * <p>
 * Naive Bayes classifier. By Bayes' theorem p(ci|x) = p(x|ci) * p(ci) / p(x); because p(x) is the
 * same for every class, it is enough to find the class ci that maximises p(x|ci) * p(ci) to decide
 * which class x belongs to.
 * </p>
 * <p>
 * Main steps:
 * Step 1: compute p(ci)
 * Step 2: compute p(xi|ci)
 * Step 3: compute p(x|ci) (the product of the p(xi|ci)) and p(x|ci) * p(ci)
 * Step 4: pick the class with the largest p(x|ci) * p(ci) as the classification result
 * </p>
 * @author Wang Haiyang
 * @date 2015-6-25 09:10:17 AM
 */
public class Bayes {
    /** Training data file used when no path is given on the command line. */
    private static final String DEFAULT_DATA_PATH = "D:/data.txt";

    /** Distinct class labels, in the iteration order of {@link #classCount}. */
    private static List<String> classifys = new ArrayList<String>();
    /** Number of training records per class label. */
    private static Map<String, Integer> classCount = new HashMap<String, Integer>();
    /** Prior probability p(ci) of each class label. */
    private static Map<String, Double> classProb = new HashMap<String, Double>();
    /** Conditional probability p(xi|ci) for every (test attribute, class label) pair. */
    private static List<ConditionProb> conditionProbs = new ArrayList<ConditionProb>();
    /** p(x|ci) for each class, index-aligned with {@link #classifys}. */
    private static List<Double> unionProbs = new ArrayList<Double>();
    /** p(x|ci) * p(ci) for each class, index-aligned with {@link #classifys}. */
    private static List<Double> lastProbs = new ArrayList<Double>();

    /**
     * Loads the training data, classifies the built-in test sample and prints the predicted label.
     *
     * @param args optional; {@code args[0]} overrides the default training data path
     */
    public static void main(String[] args) {
        String path = args.length > 0 ? args[0] : DEFAULT_DATA_PATH;
        List<ArrayList<String>> datas = configData(path);
        List<String> attributes = configAttribute();
        List<String> tests = configTestSample();
        getClassify(datas, attributes, tests);
    }

    /**
     * Classifies the test sample and prints the predicted class label to standard output.
     * When there is no training data, an error message is printed to standard error instead.
     *
     * @param datas training records; the last column of each record is its class label
     * @param attributes attribute names, index-aligned with the record columns
     * @param tests attribute values of the sample to classify, index-aligned with the record columns
     */
    private static void getClassify(List<ArrayList<String>> datas, List<String> attributes, List<String> tests) {
        if (datas.isEmpty()) {
            System.err.println("No training data available; cannot classify.");
            return;
        }
        System.out.println(classify(datas, attributes, tests));
    }

    /**
     * Returns the class label that maximises p(x|ci) * p(ci) for the given sample.
     * Ties (including the case where every score is 0) go to the label that comes first in the
     * iteration order of the internal class-count map.
     *
     * @param datas training records; the last column of each record is its class label
     * @param attributes attribute names, index-aligned with the record columns
     * @param tests attribute values of the sample to classify, index-aligned with the record columns
     * @return the predicted class label
     * @throws IllegalArgumentException if {@code datas} is empty
     */
    static String classify(List<ArrayList<String>> datas, List<String> attributes, List<String> tests) {
        if (datas.isEmpty()) {
            throw new IllegalArgumentException("training data is empty");
        }
        // Intermediate results are kept in static fields; clear them so that a second call does
        // not mix in labels and probabilities from an earlier run.
        resetState();
        // Step 1: class priors p(ci)
        computeClassProb(datas);
        // Step 2: conditional probabilities p(xi|ci)
        computeAttrProb(datas, attributes, tests);
        // Step 3: p(x|ci) and p(x|ci) * p(ci)
        computeUnionProb();
        // Step 4: the class with the highest score wins
        Double max = Collections.max(lastProbs);
        return classifys.get(lastProbs.indexOf(max));
    }

    /** Clears every intermediate result left over from a previous classification. */
    private static void resetState() {
        classifys.clear();
        classCount.clear();
        classProb.clear();
        conditionProbs.clear();
        unionProbs.clear();
        lastProbs.clear();
    }

    /**
     * For each class ci, computes p(x|ci) as the product of its p(xi|ci) terms (stored in
     * {@link #unionProbs}) and the score p(x|ci) * p(ci) (stored in {@link #lastProbs}).
     */
    private static void computeUnionProb() {
        for (String classify : classifys) {
            double union = 1D;
            for (ConditionProb cp : conditionProbs) {
                if (cp.getClassValue().equals(classify)) {
                    union *= cp.getProbility();
                }
            }
            unionProbs.add(union);
            lastProbs.add(union * classProb.get(classify));
        }
    }

    /**
     * Computes p(xi|ci) for every test attribute value xi and every class ci, and appends the
     * results to {@link #conditionProbs}. A value that never occurs together with a class gets
     * probability 0 (no smoothing is applied).
     *
     * @param datas training records; the last column of each record is its class label
     * @param attributes attribute names, index-aligned with the record columns
     * @param tests attribute values of the sample to classify, index-aligned with the record columns
     */
    private static void computeAttrProb(List<ArrayList<String>> datas, List<String> attributes, List<String> tests) {
        for (int i = 0; i < tests.size(); i++) {
            String testValue = tests.get(i);
            for (Entry<String, Integer> entry : classCount.entrySet()) {
                String classValue = entry.getKey();
                // Count the records that have this attribute value AND belong to this class.
                int matches = 0;
                for (ArrayList<String> row : datas) {
                    if (row.get(i).equals(testValue) && row.get(row.size() - 1).equals(classValue)) {
                        matches++;
                    }
                }
                ConditionProb cp = new ConditionProb();
                cp.setAttributeName(attributes.get(i));
                cp.setAttributeValue(testValue);
                cp.setClassValue(classValue);
                cp.setProbility(matches / entry.getValue().doubleValue());
                conditionProbs.add(cp);
            }
        }
    }

    /**
     * Counts the records per class label into {@link #classCount} and derives each prior p(ci)
     * into {@link #classProb}; the labels are also collected into {@link #classifys}.
     *
     * @param datas training records; the last column of each record is its class label
     */
    private static void computeClassProb(List<ArrayList<String>> datas) {
        for (ArrayList<String> row : datas) {
            String label = row.get(row.size() - 1);
            Integer count = classCount.get(label);
            classCount.put(label, count == null ? 1 : count + 1);
        }
        for (Entry<String, Integer> entry : classCount.entrySet()) {
            classifys.add(entry.getKey());
            classProb.put(entry.getKey(), entry.getValue().doubleValue() / datas.size());
        }
    }

    /**
     * Builds the sample to classify; values are index-aligned with {@link #configAttribute()}.
     *
     * @return the test sample's attribute values
     */
    private static List<String> configTestSample() {
        List<String> results = new ArrayList<String>();
        results.add("youth");
        results.add("medium");
        results.add("yes");
        results.add("fair");
        return results;
    }

    /**
     * Builds the attribute names, index-aligned with the columns of the training records.
     *
     * @return the attribute names
     */
    private static List<String> configAttribute() {
        List<String> results = new ArrayList<String>();
        results.add("age");
        results.add("income");
        results.add("student");
        results.add("credit_rating");
        return results;
    }

    /**
     * Reads the training data: one comma-separated record per line, with the class label in the
     * last column. Blank lines are skipped and each field is trimmed. On an I/O error the stack
     * trace is printed and whatever was read so far is returned (possibly an empty list).
     *
     * @param path path of the training data file
     * @return the parsed records
     */
    private static List<ArrayList<String>> configData(String path) {
        List<ArrayList<String>> results = new ArrayList<ArrayList<String>>();
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new InputStreamReader(new FileInputStream(new File(path))));
            String line;
            while ((line = reader.readLine()) != null) {
                // A blank line would otherwise become a bogus one-column record.
                if (line.trim().isEmpty()) {
                    continue;
                }
                String[] fields = line.split(",");
                ArrayList<String> row = new ArrayList<String>(fields.length);
                for (String field : fields) {
                    row.add(field.trim());
                }
                results.add(row);
            }
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        } finally {
            if (reader != null) {
                try {
                    reader.close();
                } catch (IOException ignored) {
                    // Nothing useful can be done if closing fails; the data is already read.
                }
            }
        }
        return results;
    }
}
2、ConditionProb.java
package com.bayes;
/**
 * <p>Value holder for one conditional probability p(attribute = value | class).</p>
 * <p>Note: the probability accessors are spelled {@code getProbility} / {@code setProbility}.</p>
 * @author Wang Haiyang
 * @date 2015-6-25 11:32:12 AM
 */
public class ConditionProb {

    /** Attribute name, e.g. age. */
    private String attributeName;

    /** Attribute value, e.g. youth. */
    private String attributeValue;

    /** Class label, e.g. yes or no. */
    private String classValue;

    /** The conditional probability; {@code null} until it has been set. */
    private Double probility;

    // ----- read accessors -----

    /** @return the attribute name, or {@code null} if not set */
    public String getAttributeName() {
        return this.attributeName;
    }

    /** @return the attribute value, or {@code null} if not set */
    public String getAttributeValue() {
        return this.attributeValue;
    }

    /** @return the class label, or {@code null} if not set */
    public String getClassValue() {
        return this.classValue;
    }

    /** @return the conditional probability, or {@code null} if not set */
    public Double getProbility() {
        return this.probility;
    }

    // ----- write accessors -----

    /** @param attributeName the attribute name to store */
    public void setAttributeName(String attributeName) {
        this.attributeName = attributeName;
    }

    /** @param attributeValue the attribute value to store */
    public void setAttributeValue(String attributeValue) {
        this.attributeValue = attributeValue;
    }

    /** @param classValue the class label to store */
    public void setClassValue(String classValue) {
        this.classValue = classValue;
    }

    /** @param probility the conditional probability to store */
    public void setProbility(Double probility) {
        this.probility = probility;
    }
}