贝叶斯分类器

贝叶斯


贝叶斯公式:

P(A|C)=P(C|A)P(A)P(C)
事件A在事件C发生的概率为事件C在A发生下的概率乘以事件A发生的概率,最后除上事件C发生的概率

经典场景

射击问题

A,B两人射击,A有50%的概率命中,B有60%概率命中,已知目标被命中,求分别为A、B的概率。

令目标被命中事件为C,则有:
由求贝叶斯公式可得:
P(C)=0.50.6+0.50.6+0.50.4=0.8

P(A|C)=P(C|A)P(A)P(C)=58

同理可得 P(B|C)=34

医疗检测

已知条件如下:
1. 人口统计先验有:
得癌症的概率: P(ω1)=0.008
不得癌症概率: P(ω2)=0.992
2. 医疗检测中:
阳性:
P(+|ω1)=0.98
P(+|ω2)=0.02
阴性:
P(|ω2)=0.97
P(|ω1)=0.03
那么当一次检测为阳性时,得癌症的概率有多大?
P(+)=P(ω1)P(+|ω1)+P(ω2)P(+|ω2)=0.1948
P(ω1|+)=P(+|ω1)P(ω1)P(+)=0.28
当第二次检测为阳性时,得癌症的概率为多少?
这里的计算过程不变,但是先验概率 P(ω1) 改变了,为0.28,所以要重新计算 P(+)

用贝叶斯做分类

推导过程

1.开始公式

ωmap=argmaxωiωP(ωi|a1,a2,a3..an)
其中, ai 为其中的属性。整个公式的解释是:这条数据的最终类别是 ωi 在条件{a_的概率最大的那个分类

2.用贝叶斯公式

ωmap=argmaxωiωP(a1,a2...an|ωi)P(ωi)P(a1,a2...an)

3.化简

去掉 P(a1,a2...an) ,因为每个都一样
其中 P(ωi) 是可以从训练集中统计出来的先验概率

4.引入独立条件

P(a1,a2...an|ωi)=P(a1|ωi)P(a2|ωi)...P(an|ωi)

5.最终可得到公式

ωmap=argmaxωiωP(ωi)jP(aj|ωi)
输入数据就是 aj ,也就是各个属性的值
在数据集中可以获得的数据有:
1. P(ωi)=
2. P(aj|ωi)=ωiajωi

决策树

ID3

定义:
Entropy(S)=i=1cpilog(pi)
解释:
1. S是最后的标签属性,取值范围为c
2. pi=i

信息增益

在上节中,只是计算了当前集合的总体熵,信息增益=总体熵-(用标签外的属性X来划分之后的熵)
Gain(S,X)=Entropy(S)Entropy(S|X)
Entropy(S|X)=vX|Sv||S|Entropy(Sv)
Entropy(Sv)=i=1cpilog(pi) 这个样本公式的 Sv 代表属性 X=v 的所有属性局
例子:

id是否抽烟头发长度鞋码性别(男|女)
1false100mid
2true100small
3true10big
4false20mid
5true30mid
6true70big
7false100small
8false50small

1. 总体熵:
p()=3/8
p()=5/8
Entropy()=p()log2p()p()log2p()=0.96
2. 计算 Entropy(|=true) :
p(=|=true)=0.5
p(=|=true)=0.5
Entropy(|=true)=1.0
3. 计算 Entropy(|=false)
p(=|=false)=1/4
p(=|=false)=3/4
Entropy(|=false)=14log21434log234=0.81
4. 计算 Entropy(|)
p(=true)=0.5
p(=false)=0.5
Entropy(|)=0.51.0+0.50.81=0.905
5. 信息增益
Gain(,)=0.960.905=0.095

利用信息熵

分别计算各个属性的信息增益,去最大的那个属性作为节点label

过拟合

两个分类器A、B,A在训练集中的效果比B好,但是在测试集中比B差,我们说A过拟合。

限制决策树高度

剪枝

将两个叶子节点,合并后,按照少数服从多数得出label
需要增设一个校验集,用于剪枝过程中的误差比较。
当剪枝进行到在校验集上误差由减小到增大的拐点时,停止剪枝

处理连续性数据

采用信息增益衡量按照进行对阈值切分点后的数据集的纯度,采用信息增益比较大的。

贝叶斯分类器实现

package com.liuyanzuo.datamining.classification;

import java.util.*;

/**
 * 朴素贝叶斯分类器实现
 * Created by tempuser on 2017/1/19.
 */
public class NaiveBayesClassification {
    //定义常量
    public static final String NOT_DEFINE_ATTR="not build the attributeList";
    public static final String SUCCESS="success";


    //定义存储类别信息的结构
    private Map<String,Map<String,Map<String,Integer>>> statisticsMsg;
    //定义属性名称集合
    private List<String> attributeList;
    //每个label的数量统计
    private Map<String,Integer> labelCountMap;
    //每个属性可取值的范围
    private Map<String,List<String>> attrValue;
    //每个label的每个属性的百分比统计
    private Map<String,Map<String,Map<String,Double>>> labelAttrPercentMap;
    //label在属性的下标
    private int labelIndex;
    //数据的总数量
    private int totalCount;

    /**
     * 构造分类器
     * @param data
     * @param labelIndex
     */
    public String build(List<List<String>> data,int labelIndex){
        if(null==attributeList || "".equals(attributeList)){
            return NOT_DEFINE_ATTR;
        }
        this.labelIndex=labelIndex;
        //初始化各个属性
        statisticsMsg=new HashMap<>();
        labelCountMap=new HashMap<>();
        attrValue=new HashMap<>();


        for(List<String> attributeLabelList : data){
            //这行数据的标签
            String label=attributeLabelList.get(labelIndex);
            //统计这行数据的label
            Integer labelPercentValue=labelCountMap.get(label);
            if(labelPercentValue==null){
                labelPercentValue=0;
            }
            labelPercentValue++;
            labelCountMap.put(label,labelPercentValue);
            totalCount++;

            Map<String,Map<String,Integer>> labelMap= statisticsMsg.get(label);
            if(null == labelMap){
                labelMap=new HashMap<>();
                statisticsMsg.put(label,labelMap);
            }

            for(int i=0;i<attributeLabelList.size();i++){
                if(i != labelIndex){
                    //现在所在下标的属性名称
                    String attributeName=attributeList.get(i);
                    //现在所在下标的属性值
                    String attributeValue=attributeLabelList.get(i);
                    //统计属性的取值范围
                    List<String> attrValueList=attrValue.get(attributeName);
                    if(attrValueList==null){
                        attrValueList=new ArrayList<>();
                    }
                    if(!attrValueList.contains(attributeValue)){
                        attrValueList.add(attributeValue);
                    }
                    attrValue.put(attributeName,attrValueList);

                    Map<String,Integer> attributeMap=labelMap.get(attributeName);
                    if( null == attributeMap){
                        attributeMap=new HashMap<>();
                        labelMap.put(attributeList.get(i),attributeMap);
                    }
                    Integer attributeCountValue=attributeMap.get(attributeValue);
                    if(null==attributeCountValue){
                        attributeCountValue=0;
                    }
                    attributeCountValue++;
                    attributeMap.put(attributeValue,attributeCountValue);
                }
            }
        }
        labelAttrPercentMap=new HashMap<>();
        //统计label百分比
        Set<String> labelSet=statisticsMsg.keySet();
        for(String label:labelSet){
            //这个label的总长度
            int labelCount=labelCountMap.get(label);
            Map<String,Map<String,Integer>> statisticsLabelAttrMap=statisticsMsg.get(label);
            //统计每个label下的各个属性的各个取值的数量
            Map<String,Map<String,Double>> percentValue=new HashMap<>();
            Set<String> attrSet=statisticsLabelAttrMap.keySet();
            for(String attribute:attrSet){
                Map<String,Integer> attributeValueMap=statisticsLabelAttrMap.get(attribute);
                Set<String> attributeValueSet=attributeValueMap.keySet();
                Map<String,Double> percentAttributeValueMap=new HashMap<>();
                for(String attributeValue:attributeValueSet){
                    //最终属性取值的百分比
                    percentAttributeValueMap.put(attributeValue,attributeValueMap.get(attributeValue)/(labelCount*1.0));
                }
                percentValue.put(attribute,percentAttributeValueMap);
            }
            labelAttrPercentMap.put(label,percentValue);
        }
        return SUCCESS;
    }

    /**
     * 对传入数据进行分类
     * @param needClassify
     */
    public Map<String,Double> classify(List<String> needClassify){
        Map<String,Double> result=new HashMap<>();
        if(null == statisticsMsg || statisticsMsg.size()==0){
            return result;
        }
        for(String label:labelCountMap.keySet()){
            double prediction=1.0;
            for(int i=0;i<attributeList.size();i++){
                //当前属性名称
                String attrName=attributeList.get(i);

                if(i != labelIndex){
                    //要做一个laplace平滑
                    Integer labelAttrPercentValue=statisticsMsg.get(label).get(attributeList.get(i)).get(needClassify.get(i));
                    if(labelAttrPercentValue==null){
                        labelAttrPercentValue=0;
                    }
                    prediction*=(labelAttrPercentValue+1.0)/(attrValue.get(attrName).size()*1.0+labelCountMap.get(label)*1.0);
                }
            }
            result.put(label,prediction);
        }
        return result;
    }
    public Map<String, Map<String, Map<String, Integer>>> getStatisticsMsg() {
        return statisticsMsg;
    }

    public void setStatisticsMsg(Map<String, Map<String, Map<String, Integer>>> statisticsMsg) {
        this.statisticsMsg = statisticsMsg;
    }
    public List<String> getAttributeList() {
        return attributeList;
    }

    public void setAttributeList(List<String> attributeList) {
        this.attributeList = attributeList;
    }

    public Map<String, Integer> getLabelCountMap() {
        return labelCountMap;
    }

    public Map<String, Map<String, Map<String, Double>>> getLabelAttrPercentMap() {
        return labelAttrPercentMap;
    }

    public void setLabelAttrPercentMap(Map<String, Map<String, Map<String, Double>>> labelAttrPercentMap) {
        this.labelAttrPercentMap = labelAttrPercentMap;
    }
}
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值