这里的困难在于,由于身高、体重、脚掌都是连续变量,不能采用离散变量的方法计算概率。而且由于样本太少,所以也无法分成区间计算。怎么办? 这时,可以假设男性和女性的身高、体重、脚掌都是正态分布,通过样本计算出均值和方差,也就是得到正态分布的密度函数。有了密度函数,就可以把值代入,算出某一点的密度函数的值。 import java.math.BigDecimal; import java.math.MathContext; import java.math.RoundingMode; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; /** * @author xjh * @date 2020-10-29-13:19 */ public class Bayes { private static List<List<String>> datas; static { datas = new ArrayList<>(); List<String> source1 = new ArrayList<>(); source1.add("男"); source1.add("6"); source1.add("180"); source1.add("12"); datas.add(source1); List<String> source2 = new ArrayList<>(); source2.add("男"); source2.add("5.92"); source2.add("190"); source2.add("11"); datas.add(source2); List<String> source3 = new ArrayList<>(); source3.add("男"); source3.add("5.58"); source3.add("170"); source3.add("12"); datas.add(source3); List<String> source4 = new ArrayList<>(); source4.add("男"); source4.add("5.92"); source4.add("165"); source4.add("10"); datas.add(source4); List<String> source5 = new ArrayList<>(); source5.add("女"); source5.add("5"); source5.add("100"); source5.add("6"); datas.add(source5); List<String> source6 = new ArrayList<>(); source6.add("女"); source6.add("5.5"); source6.add("150"); source6.add("8"); datas.add(source6); List<String> source7 = new ArrayList<>(); source7.add("女"); source7.add("5.42"); source7.add("130"); source7.add("7"); datas.add(source7); List<String> source8 = new ArrayList<>(); source8.add("女"); source8.add("5.75"); source8.add("150"); source8.add("9"); datas.add(source8); } //朴素贝叶斯公式 // P(c\w)=P(cw)/P(w) //p(c\w) = p(w\c)p(c)/p(w) //p(性别\身高,体重,脚掌)=p(身高,体重,脚掌\性别)p(性别)/p(身高,体重,脚掌) // =p(身高\性别)p(体重\性别)p(脚掌\性别)p(性别)/p(身高)p(体重)p(脚掌) // p(女|6,130,8) = p(6|女)p(130|女)p(8|女)p(女)/p(6)p(130)p(8) // p(男|6,130,8) = p(6|男)p(130|男)p(8|男)p(男)/p(6)p(130)p(8) public static String bayesCheckSex(List<String> characteristics){ Map<String,List<String>> map = new HashMap<>(); for(int a = 0;a<datas.size();a++){ for(int b = 0;b<datas.get(a).size();b++){ //身高 if(b==1){ String key = datas.get(a).get(0)+"_身高"; diegst(map,a,b,key); } //体重 if(b==2){ String key = datas.get(a).get(0)+"_体重"; diegst(map,a,b,key); } //脚掌 if(b==3){ String key = datas.get(a).get(0)+"_脚掌"; diegst(map,a,b,key); } } } //密度函数 Map<String,BigDecimal> map1 = new HashMap<>(); for(Map.Entry<String,List<String>> entity:map.entrySet()){ List<String> values = entity.getValue(); BigDecimal total = BigDecimal.ZERO; for(String value:values){ total = total.add(new BigDecimal(value)); } //平均值 BigDecimal avg = total.divide(new BigDecimal(values.size())); //方差分子 BigDecimal total1 = BigDecimal.ZERO; for(String value:values){ total1 = total1.add(new BigDecimal(value).subtract(avg).multiply(new BigDecimal(value).subtract(avg))); } //方差 BigDecimal u = total1.divide(new BigDecimal(values.size())); BigDecimal source = BigDecimal.ZERO; if(entity.getKey().contains("身高")){ source = new BigDecimal(characteristics.get(0)); }if(entity.getKey().contains("体重")){ source = new BigDecimal(characteristics.get(1)); }if(entity.getKey().contains("脚掌")){ source = new BigDecimal(characteristics.get(2)); } //密度函数 // 均值为:μ,方差为:σ²,那么对应的概率密度函数为: // f(x) = [1/√(2π)] exp{-(x-μ)²/2σ²} int a = source.subtract(avg).multiply(source.subtract(avg)).negate().divide(new BigDecimal(BigDecimal.ROUND_CEILING).multiply(u), 4, BigDecimal.ROUND_HALF_UP).intValue(); //a^(-n)=1/(a^n),比如:3^(-2)=1/3^2=1/9 BigDecimal exp = BigDecimal.ZERO; if(a>0){ exp = new BigDecimal(Math.E).pow(a); }else { exp = new BigDecimal(BigDecimal.ROUND_DOWN).divide(new BigDecimal(Math.E).pow(Math.abs(a)), 10, BigDecimal.ROUND_HALF_UP); } BigDecimal density = new BigDecimal(BigDecimal.ROUND_DOWN).divide(sqrt(new BigDecimal(BigDecimal.ROUND_CEILING) .multiply(new BigDecimal(Math.PI)).multiply(u), 4), 4, BigDecimal.ROUND_HALF_UP) .multiply(exp); map1.put(entity.getKey(),density); } BigDecimal boy = map1.get("男_身高").multiply(map1.get("男_体重")).multiply(map1.get("男_脚掌")); BigDecimal girl = map1.get("女_身高").multiply(map1.get("女_体重")).multiply(map1.get("女_脚掌")); if(girl.compareTo(boy)>0){ return "女"; }else { return "男"; } } public static Map<String,List<String>> diegst(Map<String,List<String>> map ,int a,int b,String key){ if(map.containsKey(key)){ List<String> list = map.get(key); list.add(datas.get(a).get(b)); map.put(key,list); }else { List<String> list = new ArrayList<>(); list.add(datas.get(a).get(b)); map.put(key,list); } return map; } //开平方 public static BigDecimal sqrt(BigDecimal value, int scale){ BigDecimal num2 = BigDecimal.valueOf(2); int precision = 100; MathContext mc = new MathContext(precision, RoundingMode.HALF_UP); BigDecimal deviation = value; int cnt = 0; while (cnt < precision) { deviation = (deviation.add(value.divide(deviation, mc))).divide(num2, mc); cnt++; } deviation = deviation.setScale(scale, BigDecimal.ROUND_HALF_UP); return deviation; } public static void main(String[] args) { List<String> characteristics = new ArrayList<>(); characteristics.add("6"); characteristics.add("130"); characteristics.add("8"); System.out.println(bayesCheckSex(characteristics)); } }
java 朴素贝叶斯算法
最新推荐文章于 2023-07-03 13:10:14 发布