贝叶斯
贝叶斯公式:
P(A|C)=P(C|A)P(A)P(C)
事件A在事件C发生的概率为事件C在A发生下的概率乘以事件A发生的概率,最后除上事件C发生的概率
经典场景
射击问题
A,B两人射击,A有50%的概率命中,B有60%概率命中,已知目标被命中,求分别为A、B的概率。
令目标被命中事件为C,则有:
由求贝叶斯公式可得:
P(C)=0.5∗0.6+0.5∗0.6+0.5∗0.4=0.8
P(A|C)=P(C|A)P(A)P(C)=58
同理可得 P(B|C)=34
医疗检测
已知条件如下:
1. 人口统计先验有:
得癌症的概率:
P(ω1)=0.008
不得癌症概率:
P(ω2)=0.992
2. 医疗检测中:
阳性:
P(+|ω1)=0.98
P(+|ω2)=0.02
阴性:
P(−|ω2)=0.97
P(−|ω1)=0.03
那么当一次检测为阳性时,得癌症的概率有多大?
P(+)=P(ω1)∗P(+|ω1)+P(ω2)∗P(+|ω2)=0.1948
P(ω1|+)=P(+|ω1)P(ω1)P(+)=0.28
当第二次检测为阳性时,得癌症的概率为多少?
这里的计算过程不变,但是先验概率
P(ω1)
改变了,为0.28,所以要重新计算
P(+)
用贝叶斯做分类
推导过程
1.开始公式
ωmap=argmaxωi∈ωP(ωi|a1,a2,a3..an)
其中,
ai
为其中的属性。整个公式的解释是:这条数据的最终类别是
ωi
在条件{a_的概率最大的那个分类
2.用贝叶斯公式
ωmap=argmaxωi∈ωP(a1,a2...an|ωi)P(ωi)P(a1,a2...an)
3.化简
去掉
P(a1,a2...an)
,因为每个都一样
其中
P(ωi)
是可以从训练集中统计出来的先验概率
4.引入独立条件
P(a1,a2...an|ωi)=P(a1|ωi)P(a2|ωi)...P(an|ωi)
5.最终可得到公式
ωmap=argmaxωi∈ωP(ωi)∏jP(aj|ωi)
输入数据就是
aj
,也就是各个属性的值
在数据集中可以获得的数据有:
1.
P(ωi)=每个分类数量总数
2.
P(aj|ωi)=ωi中aj的数量ωi总数
决策树
ID3
熵
定义:
Entropy(S)=−∑i=1cpilog(pi)
解释:
1. S是最后的标签属性,取值范围为c
2.
pi=标签为i的数据数总数据数
信息增益
在上节中,只是计算了当前集合的总体熵,信息增益=总体熵-(用标签外的属性X来划分之后的熵)
Gain(S,X)=Entropy(S)−Entropy(S|X)
Entropy(S|X)=∑v∈X|Sv||S|Entropy(Sv)
Entropy(Sv)=−∑i=1cpilog(pi)
这个样本公式的
Sv
代表属性
X=v
的所有属性局
例子:
id | 是否抽烟 | 头发长度 | 鞋码 | 性别(男|女) |
---|---|---|---|---|
1 | false | 100 | mid | 女 |
2 | true | 100 | small | 女 |
3 | true | 10 | big | 男 |
4 | false | 20 | mid | 男 |
5 | true | 30 | mid | 女 |
6 | true | 70 | big | 男 |
7 | false | 100 | small | 女 |
8 | false | 50 | small | 女 |
1. 总体熵:
p(男)=3/8
p(女)=5/8
Entropy(性别)=−p(男)∗log2p(男)−p(女)∗log2p(女)=0.96
2. 计算
Entropy(性别|抽烟=true)
:
p(性别=男|抽烟=true)=0.5
p(性别=女|抽烟=true)=0.5
Entropy(性别|抽烟=true)=1.0
3. 计算
Entropy(性别|抽烟=false)
p(性别=男|抽烟=false)=1/4
p(性别=女|抽烟=false)=3/4
Entropy(性别|抽烟=false)=−14∗log214−34∗log234=0.81
4. 计算
Entropy(性别|抽烟)
p(抽烟=true)=0.5
p(抽烟=false)=0.5
Entropy(性别|抽烟)=0.5∗1.0+0.5∗0.81=0.905
5. 信息增益
Gain(性别,抽烟)=0.96−0.905=0.095
利用信息熵
分别计算各个属性的信息增益,去最大的那个属性作为节点label
过拟合
两个分类器A、B,A在训练集中的效果比B好,但是在测试集中比B差,我们说A过拟合。
限制决策树高度
剪枝
将两个叶子节点,合并后,按照少数服从多数得出label
需要增设一个校验集,用于剪枝过程中的误差比较。
当剪枝进行到在校验集上误差由减小到增大的拐点时,停止剪枝
处理连续性数据
采用信息增益衡量按照进行对阈值切分点后的数据集的纯度,采用信息增益比较大的。
贝叶斯分类器实现
package com.liuyanzuo.datamining.classification;
import java.util.*;
/**
* 朴素贝叶斯分类器实现
* Created by tempuser on 2017/1/19.
*/
public class NaiveBayesClassification {
//定义常量
public static final String NOT_DEFINE_ATTR="not build the attributeList";
public static final String SUCCESS="success";
//定义存储类别信息的结构
private Map<String,Map<String,Map<String,Integer>>> statisticsMsg;
//定义属性名称集合
private List<String> attributeList;
//每个label的数量统计
private Map<String,Integer> labelCountMap;
//每个属性可取值的范围
private Map<String,List<String>> attrValue;
//每个label的每个属性的百分比统计
private Map<String,Map<String,Map<String,Double>>> labelAttrPercentMap;
//label在属性的下标
private int labelIndex;
//数据的总数量
private int totalCount;
/**
* 构造分类器
* @param data
* @param labelIndex
*/
public String build(List<List<String>> data,int labelIndex){
if(null==attributeList || "".equals(attributeList)){
return NOT_DEFINE_ATTR;
}
this.labelIndex=labelIndex;
//初始化各个属性
statisticsMsg=new HashMap<>();
labelCountMap=new HashMap<>();
attrValue=new HashMap<>();
for(List<String> attributeLabelList : data){
//这行数据的标签
String label=attributeLabelList.get(labelIndex);
//统计这行数据的label
Integer labelPercentValue=labelCountMap.get(label);
if(labelPercentValue==null){
labelPercentValue=0;
}
labelPercentValue++;
labelCountMap.put(label,labelPercentValue);
totalCount++;
Map<String,Map<String,Integer>> labelMap= statisticsMsg.get(label);
if(null == labelMap){
labelMap=new HashMap<>();
statisticsMsg.put(label,labelMap);
}
for(int i=0;i<attributeLabelList.size();i++){
if(i != labelIndex){
//现在所在下标的属性名称
String attributeName=attributeList.get(i);
//现在所在下标的属性值
String attributeValue=attributeLabelList.get(i);
//统计属性的取值范围
List<String> attrValueList=attrValue.get(attributeName);
if(attrValueList==null){
attrValueList=new ArrayList<>();
}
if(!attrValueList.contains(attributeValue)){
attrValueList.add(attributeValue);
}
attrValue.put(attributeName,attrValueList);
Map<String,Integer> attributeMap=labelMap.get(attributeName);
if( null == attributeMap){
attributeMap=new HashMap<>();
labelMap.put(attributeList.get(i),attributeMap);
}
Integer attributeCountValue=attributeMap.get(attributeValue);
if(null==attributeCountValue){
attributeCountValue=0;
}
attributeCountValue++;
attributeMap.put(attributeValue,attributeCountValue);
}
}
}
labelAttrPercentMap=new HashMap<>();
//统计label百分比
Set<String> labelSet=statisticsMsg.keySet();
for(String label:labelSet){
//这个label的总长度
int labelCount=labelCountMap.get(label);
Map<String,Map<String,Integer>> statisticsLabelAttrMap=statisticsMsg.get(label);
//统计每个label下的各个属性的各个取值的数量
Map<String,Map<String,Double>> percentValue=new HashMap<>();
Set<String> attrSet=statisticsLabelAttrMap.keySet();
for(String attribute:attrSet){
Map<String,Integer> attributeValueMap=statisticsLabelAttrMap.get(attribute);
Set<String> attributeValueSet=attributeValueMap.keySet();
Map<String,Double> percentAttributeValueMap=new HashMap<>();
for(String attributeValue:attributeValueSet){
//最终属性取值的百分比
percentAttributeValueMap.put(attributeValue,attributeValueMap.get(attributeValue)/(labelCount*1.0));
}
percentValue.put(attribute,percentAttributeValueMap);
}
labelAttrPercentMap.put(label,percentValue);
}
return SUCCESS;
}
/**
* 对传入数据进行分类
* @param needClassify
*/
public Map<String,Double> classify(List<String> needClassify){
Map<String,Double> result=new HashMap<>();
if(null == statisticsMsg || statisticsMsg.size()==0){
return result;
}
for(String label:labelCountMap.keySet()){
double prediction=1.0;
for(int i=0;i<attributeList.size();i++){
//当前属性名称
String attrName=attributeList.get(i);
if(i != labelIndex){
//要做一个laplace平滑
Integer labelAttrPercentValue=statisticsMsg.get(label).get(attributeList.get(i)).get(needClassify.get(i));
if(labelAttrPercentValue==null){
labelAttrPercentValue=0;
}
prediction*=(labelAttrPercentValue+1.0)/(attrValue.get(attrName).size()*1.0+labelCountMap.get(label)*1.0);
}
}
result.put(label,prediction);
}
return result;
}
public Map<String, Map<String, Map<String, Integer>>> getStatisticsMsg() {
return statisticsMsg;
}
public void setStatisticsMsg(Map<String, Map<String, Map<String, Integer>>> statisticsMsg) {
this.statisticsMsg = statisticsMsg;
}
public List<String> getAttributeList() {
return attributeList;
}
public void setAttributeList(List<String> attributeList) {
this.attributeList = attributeList;
}
public Map<String, Integer> getLabelCountMap() {
return labelCountMap;
}
public Map<String, Map<String, Map<String, Double>>> getLabelAttrPercentMap() {
return labelAttrPercentMap;
}
public void setLabelAttrPercentMap(Map<String, Map<String, Map<String, Double>>> labelAttrPercentMap) {
this.labelAttrPercentMap = labelAttrPercentMap;
}
}