贝叶斯分类(java)

1、Bayes.java

package com.bayes;


import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;


/**
 * <p>
 *     本类描述: 
 *         本类主要实现朴素贝叶斯分类,因p(ci/x) = p(x/ci)*p(ci)/p(x), 所有仅需最大化p(x/ci)*p(ci),
 *         便能求出x具体属于哪个类别
 * </p>
 * <p>
 *     主要步骤: 
 *         步骤1: 计算p(ci)
 *         步骤2: 计算p(xi/ci)
 *         步骤3: 计算p(x/ci)和p(x/ci)*p(ci)
 *         步骤3: 找出p(x/ci)*p(ci)的最大值,得出分类结论
 * </p>
 * @author Wang Haiyang
 * @date 2015-6-25 上午09:10:17
 */
public class Bayes {
    
    /** 存储每个类别 */
    private static List<String> classifys = new ArrayList<String>();
    
    /** 存储每个类别的个数 */
    private static Map<String, Integer> classCount = new HashMap<String, Integer>();
    
    /** 存储每个类别的概率 */
    private static Map<String, Double> classProb = new HashMap<String, Double>();


    /** 存储每个属性的条件概率 */
    private static List<ConditionProb> conditionProbs = new ArrayList<ConditionProb>();
    
    /** 存储每个类别的联合条件概率 */
    private static List<Double> unionProbs = new ArrayList<Double>();
    
    private static List<Double> lastProbs = new ArrayList<Double>();
    
    public static void main(String[] args) {
        // 初始化数据
        List<ArrayList<String>> datas = configData();
        List<String> attributes = configAttribute();
        List<String> tests = configTestSample();
        getClassify(datas, attributes, tests);
    }


    /**
     * 方法描述:得到类标号
     * @param datas
     * @param attributes
     * @param tests
     */
    private static void getClassify(List<ArrayList<String>> datas, List<String> attributes, List<String> tests) {
        // 计算每个类别的概率
        computeClassProb(datas);
        
        // 计算每个属性的条件概率
        computeAttrProb(datas, attributes, tests);
        
        // 计算每个类别的联合条件概率
        computeUnionProb();
        
       // 找出类标号
        Double max = Collections.max(lastProbs);
        String last = classifys.get(lastProbs.indexOf(max));
        System.out.println(last);
    }


    /**
     * 方法描述:计算每个类别的联合条件概率
     */
    private static void computeUnionProb() {
        for (String classify : classifys) {
            Double union = 1D;
            for (ConditionProb cp : conditionProbs) {
                if (cp.getClassValue().equals(classify)) {
                    union *= cp.getProbility();
                }
            }
            unionProbs.add(union);
        }
        for (int i = 0; i < classifys.size(); i++) {
            lastProbs.add(unionProbs.get(i) * classProb.get(classifys.get(i)));
        }
    }


    /**
     * 方法描述:计算每个属性的条件概率
     * @param datas
     * @param attributes
     * @param tests
     */
    private static void computeAttrProb(List<ArrayList<String>> datas, List<String> attributes, List<String> tests) {
        for(int i = 0; i < tests.size(); i++) {
            for (Entry<String, Integer> entry : classCount.entrySet()) { // 计算每个属性和每个类的概率
                Map<String, Integer> testCount = new HashMap<String, Integer>();
                for (ArrayList<String> lists : datas) { // 计算每个属性的个数
                    if(lists.get(i).equals(tests.get(i)) && lists.get(lists.size() - 1).equals(entry.getKey())) {
                        Integer value = testCount.get(lists.get(i));
                        testCount.put(lists.get(i), value == null ? 1 : ++value);
                    }
                }
                ConditionProb cps = new ConditionProb();
                cps.setAttributeName(attributes.get(i));
                cps.setAttributeValue(tests.get(i));
                cps.setClassValue(entry.getKey());
                Double probility = Double.parseDouble(String.valueOf(testCount.get(tests.get(i)))) / Double.parseDouble(String.valueOf(entry.getValue()));;
                cps.setProbility(probility );
                conditionProbs.add(cps);
            }
        }
    }


    /**
     * 方法描述:计算每个类别的概率
     * @param datas
     */
    private static void computeClassProb(List<ArrayList<String>> datas) {
        for (ArrayList<String> lists : datas) {
            Integer value = classCount.get(lists.get(lists.size() - 1));
            classCount.put(lists.get(lists.size() - 1), value == null ? 1 : ++value);
        }
        for (Entry<String, Integer> entry : classCount.entrySet()) {
            classifys.add(entry.getKey());
            Double value = Double.parseDouble(String.valueOf(entry.getValue())) / Double.parseDouble(String.valueOf(datas.size()));
            classProb.put(entry.getKey(), value);
        }
    }
    
    private static List<String> configTestSample() {
        List<String> results = new ArrayList<String>();
        results.add("youth");
        results.add("medium");
        results.add("yes");
        results.add("fair");
        return results;
    }


    private static List<String> configAttribute() {
        List<String> results = new ArrayList<String>();
        results.add("age");
        results.add("income");
        results.add("student");
        results.add("credit_rating");
        return results;
    }


    private static List<ArrayList<String>> configData() {
        List<ArrayList<String>> results = new ArrayList<ArrayList<String>>();
        try {
            BufferedReader is = new BufferedReader(new InputStreamReader(new FileInputStream(new File("D:/data.txt"))));
            String line = is.readLine();
            while (line != null) {
                String[] split = line.split(",");
                ArrayList<String> s1s = new ArrayList<String>();
                for (int i = 0; i < split.length; i++) {
                    s1s.add(split[i]);
                }
                results.add(s1s);
                line = is.readLine();
            }
        } catch (FileNotFoundException e) {
            e.printStackTrace();
        } catch (IOException e) {
            e.printStackTrace();
        }
        return results;
    }
}

2、ConditionProb.java

package com.bayes;


/**
 * <p>本类描述: 主要用来存放条件概率</p>
 * <p>其他说明: </p>
 * @author Wang Haiyang
 * @date 2015-6-25 上午11:32:12
 */
public class ConditionProb {


    /** 属性名 eg: age*/
    private String attributeName;
    
    /** 属性值 eg: youth*/
    private String attributeValue;
    
    /** 分类值,eg: yes, no */
    private String classValue;
    
    /** 概率 */
    private Double probility;


    public String getAttributeName() {
        return attributeName;
    }


    public void setAttributeName(String attributeName) {
        this.attributeName = attributeName;
    }


    public String getAttributeValue() {
        return attributeValue;
    }


    public void setAttributeValue(String attributeValue) {
        this.attributeValue = attributeValue;
    }


    public String getClassValue() {
        return classValue;
    }


    public void setClassValue(String classValue) {
        this.classValue = classValue;
    }


    public Double getProbility() {
        return probility;
    }


    public void setProbility(Double probility) {
        this.probility = probility;
    }
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值