java 朴素贝叶斯算法

这里的困难在于,由于身高、体重、脚掌都是连续变量,不能采用离散变量的方法计算概率。而且由于样本太少,所以也无法分成区间计算。怎么办?
这时,可以假设男性和女性的身高、体重、脚掌都是正态分布,通过样本计算出均值和方差,也就是得到正态分布的密度函数。有了密度函数,就可以把值代入,算出某一点的密度函数的值。

import java.math.BigDecimal;
import java.math.MathContext;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * @author xjh
 * @date 2020-10-29-13:19
 */
public class Bayes {
    private static List<List<String>> datas;
    static {
        datas = new ArrayList<>();
        List<String> source1 = new ArrayList<>();
        source1.add("男");
        source1.add("6");
        source1.add("180");
        source1.add("12");
        datas.add(source1);
        List<String> source2 = new ArrayList<>();
        source2.add("男");
        source2.add("5.92");
        source2.add("190");
        source2.add("11");
        datas.add(source2);
        List<String> source3 = new ArrayList<>();
        source3.add("男");
        source3.add("5.58");
        source3.add("170");
        source3.add("12");
        datas.add(source3);
        List<String> source4 = new ArrayList<>();
        source4.add("男");
        source4.add("5.92");
        source4.add("165");
        source4.add("10");
        datas.add(source4);
        List<String> source5 = new ArrayList<>();
        source5.add("女");
        source5.add("5");
        source5.add("100");
        source5.add("6");
        datas.add(source5);
        List<String> source6 = new ArrayList<>();
        source6.add("女");
        source6.add("5.5");
        source6.add("150");
        source6.add("8");
        datas.add(source6);
        List<String> source7 = new ArrayList<>();
        source7.add("女");
        source7.add("5.42");
        source7.add("130");
        source7.add("7");
        datas.add(source7);
        List<String> source8 = new ArrayList<>();
        source8.add("女");
        source8.add("5.75");
        source8.add("150");
        source8.add("9");
        datas.add(source8);
    }
    //朴素贝叶斯公式
   // P(c\w)=P(cw)/P(w)
    //p(c\w) = p(w\c)p(c)/p(w)
    //p(性别\身高,体重,脚掌)=p(身高,体重,脚掌\性别)p(性别)/p(身高,体重,脚掌)
    //                        =p(身高\性别)p(体重\性别)p(脚掌\性别)p(性别)/p(身高)p(体重)p(脚掌)
    //          p(女|6,130,8) = p(6|女)p(130|女)p(8|女)p(女)/p(6)p(130)p(8)
    //          p(男|6,130,8) = p(6|男)p(130|男)p(8|男)p(男)/p(6)p(130)p(8)
    public static String bayesCheckSex(List<String> characteristics){
        Map<String,List<String>> map = new HashMap<>();
        for(int a = 0;a<datas.size();a++){
            for(int b = 0;b<datas.get(a).size();b++){
                //身高
                if(b==1){
                   String key = datas.get(a).get(0)+"_身高";
                   diegst(map,a,b,key);
                }
                //体重
                if(b==2){
                    String key = datas.get(a).get(0)+"_体重";
                    diegst(map,a,b,key);
                }
                //脚掌
                if(b==3){
                    String key = datas.get(a).get(0)+"_脚掌";
                    diegst(map,a,b,key);
                }
            }
        }


        //密度函数
         Map<String,BigDecimal> map1 = new HashMap<>();
        for(Map.Entry<String,List<String>> entity:map.entrySet()){
            List<String> values = entity.getValue();
            BigDecimal total =  BigDecimal.ZERO;
            for(String value:values){
                total = total.add(new BigDecimal(value));
            }
            //平均值
            BigDecimal avg = total.divide(new BigDecimal(values.size()));

            //方差分子
            BigDecimal total1 =  BigDecimal.ZERO;
            for(String value:values){
                total1 = total1.add(new BigDecimal(value).subtract(avg).multiply(new BigDecimal(value).subtract(avg)));
            }
            //方差
            BigDecimal u = total1.divide(new BigDecimal(values.size()));


            BigDecimal source = BigDecimal.ZERO;
            if(entity.getKey().contains("身高")){
                source = new BigDecimal(characteristics.get(0));
            }if(entity.getKey().contains("体重")){
                source = new BigDecimal(characteristics.get(1));
            }if(entity.getKey().contains("脚掌")){
                source = new BigDecimal(characteristics.get(2));
            }
            //密度函数
            // 均值为:μ,方差为:σ²,那么对应的概率密度函数为:
            // f(x) = [1/√(2π)] exp{-(x-μ)²/2σ²}
            int a = source.subtract(avg).multiply(source.subtract(avg)).negate().divide(new BigDecimal(BigDecimal.ROUND_CEILING).multiply(u), 4, BigDecimal.ROUND_HALF_UP).intValue();
            //a^(-n)=1/(a^n),比如:3^(-2)=1/3^2=1/9
            BigDecimal exp = BigDecimal.ZERO;
            if(a>0){
                exp = new BigDecimal(Math.E).pow(a);
            }else {
                exp = new BigDecimal(BigDecimal.ROUND_DOWN).divide(new BigDecimal(Math.E).pow(Math.abs(a)), 10, BigDecimal.ROUND_HALF_UP);
            }
            BigDecimal density = new BigDecimal(BigDecimal.ROUND_DOWN).divide(sqrt(new BigDecimal(BigDecimal.ROUND_CEILING)
                    .multiply(new BigDecimal(Math.PI)).multiply(u), 4), 4, BigDecimal.ROUND_HALF_UP)
                    .multiply(exp);
            map1.put(entity.getKey(),density);
        }

        BigDecimal boy = map1.get("男_身高").multiply(map1.get("男_体重")).multiply(map1.get("男_脚掌"));
        BigDecimal girl = map1.get("女_身高").multiply(map1.get("女_体重")).multiply(map1.get("女_脚掌"));

        if(girl.compareTo(boy)>0){
            return "女";
        }else {
            return "男";
        }
    }



    public static Map<String,List<String>>  diegst(Map<String,List<String>> map ,int a,int b,String key){
        if(map.containsKey(key)){
            List<String> list = map.get(key);
            list.add(datas.get(a).get(b));
            map.put(key,list);
        }else {
            List<String> list = new ArrayList<>();
            list.add(datas.get(a).get(b));
            map.put(key,list);
        }
        return map;
    }


    //开平方
    public static BigDecimal sqrt(BigDecimal value, int scale){
        BigDecimal num2 = BigDecimal.valueOf(2);
        int precision = 100;
        MathContext mc = new MathContext(precision, RoundingMode.HALF_UP);
        BigDecimal deviation = value;
        int cnt = 0;
        while (cnt < precision) {
            deviation = (deviation.add(value.divide(deviation, mc))).divide(num2, mc);
            cnt++;
        }
        deviation = deviation.setScale(scale, BigDecimal.ROUND_HALF_UP);
        return deviation;
    }

    public static void main(String[] args) {
        List<String> characteristics = new ArrayList<>();
        characteristics.add("6");
        characteristics.add("130");
        characteristics.add("8");
        System.out.println(bayesCheckSex(characteristics));
    }

}
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值