贝叶斯分类是一类分类算法的总称,这类算法均以贝叶斯定理为基础,故统称为贝叶斯分类。而朴素贝叶斯分类是贝叶斯分类中最简单,也是常见的一种分类方法。
它的核心算法就是下面这个贝叶斯公式:
换个表达形式,如下:
我们最终求的p(类别|特征)即可,就相当于完成了我们的任务。
@Override
public void buildClassifier(Instances data) throws Exception
{
//检测分类器能否处理数据
getCapabilities().testWithFail(data);
//删除具有类别缺失值的实例
data=new Instances(data);
data.deleteWithMissingClass();
//保存类别的数量
m_NumClasses=data.numClasses();
//复制训练集
m_Instances=new Instances(data);
//如果指定,就对数据进行离散化
if(m_UseDiscretization)
{
m_Disc=new weka.filters.supervised.attribute.Discretize();
m_Disc.setInputFormat(data);
m_Instances=weka.filters.Filter.useFilter(m_Instances, m_Disc);
}
else
{
m_Disc=null;
}
//为概率分布预留空间
//类别条件概率分布P(X|Y)
m_Distributions=new Estimator[m_Instances.numAttributes()-1][m_Instances.numClasses()];
//类别分布P(Y)
m_ClassDistribution=new DiscreteEstimator(m_Instances.numClasses(), true);
int attIndex=0;
Enumeration enumeration=m_Instances.enumerateAttributes();
//循环处理每一个属性
while(enumeration.hasMoreElements())
{
Attribute attribute=(Attribute) enumeration.nextElement();
//如果属性是数值型,根据相邻值之间的差异,测定估计器数值精度
double numPrecision=DEFAULT_NUM_PRECISION;
if(attribute.type()==Attribute.NUMERIC)
{
//根据当前属性的值对数据集排序
m_Instances.sort(attribute);
//排序之后,当前属性缺失值的实例就排到最前
//这样,判断第一个样本是否有缺失值,就知道整体样本是否有缺失值
//如果有,就没有必要执行if后面的代码块
if((m_Instances.numInstances()>0) && !m_Instances.instance(0).isMissing(attribute))
{
//lastVal为后一个实例的当前属性值
double lastVal=m_Instances.instance(0).value(attribute);
//currentVal,为每个实例的当前属性值,deltaSum为差值
double currentVal,deltaSum=0;
//distinct为当前属性取不同值的数量
int distinct=0;
for(int i=1;i<m_Instances.numInstances();i++)
{
Instance currentInst=m_Instances.instance(i);
if(currentInst.isMissing(attribute))
{
break;
}
currentVal=currentInst.value(attribute);
//如果当前值与最后值不相等,则相减并将差值累加到deltaSum
if(currentVal!=lastVal)
{
deltaSum+=currentVal-lastVal;
lastVal=currentVal;
distinct++;
}
}
//最终的numPrecision就是deltaSum/distinct
if(distinct>0)
{
numPrecision=deltaSum/distinct;
}
}
}
//循环处理每一个类别标签
for(int j=0;j<m_Instances.numClasses();j++)
{
//判断当前属性的类型
switch(attribute.type())
{
//如果为连续的数值型属性,根据是否使用核估计器的选项,选择构建Kernelstimator对象还是NormalEstimator对象
//两者的构造函数都是使用numPrecision作为参数
case Attribute.NUMERIC:
if(m_UseKernelEstimator)
{
m_Distributions[attIndex][j]=new KernelEstimator(numPrecision);
}
else
{
m_Distributions[attIndex][j]=new NormalEstimator(numPrecision);
}
break;
case Attribute.NOMINAL:
m_Distributions[attIndex][j]=new DiscreteEstimator(attribute.numValues(), true);
break;
default:
throw new Exception("Attribute type unkown to my NB");
}
}
attIndex++;
}
//统计每一个实例
Enumeration enumInsts=m_Instances.enumerateInstances();
while (enumInsts.hasMoreElements())
{
Instance instance=(Instance) enumInsts.nextElement();
//调用updateClassifier方法,用实例更新分离器
updateClassifier(instance);
}
//节省空间
m_Instances=new Instances(m_Instances,0);
}
public void updateClassifier(Instance instance)
{
if(!instance.classIsMissing())
{
Enumeration enumAtts=m_Instances.enumerateAttributes();
int attIndex=0;
//循环处理没一个属性
while (enumAtts.hasMoreElements())
{
Attribute attribute = (Attribute) enumAtts.nextElement();
if(!instance.isMissing(attribute))
{
//m_Distributons第一个下标记为当亲属性下标记,第二个下标为类别值
//统计样本实例对应类别属性值的分布
//调用Estimator的AddValue方法将新数据值加入到当前评估器中
m_Distributions[attIndex][(int)instance.classValue()].addValue(instance.value(attribute),
instance.weight());
}
attIndex++;
}
//统计类别分布
m_ClassDistribution.addValue(instance.classValue(), instance.weight());
}
}
public double[] distributionForInstance(Instance instance) throws Exception
{
//如果使用useSupervisedDiscretization选项,就对实例进行离散化
if(m_UseDiscretization)
{
m_Disc.input(instance);
instance=m_Disc.output();
}
//类别的概率P(Y)
double probs[]=new double[m_NumClasses];
//循环得到每个类别的概率
for(int j=0;j<m_NumClasses;j++)
{
probs[j]=m_ClassDistribution.getProbability(j);
}
Enumeration enumAtts=instance.enumerateAttributes();
int attIndex=0;
//循环处理每个属性
while(enumAtts.hasMoreElements())
{
Attribute attribute=(Attribute) enumAtts.nextElement();
if(!instance.isMissing(attribute))
{
//temp为临时概率,max为当前最大概率
double temp,max=0;
for (int j = 0; j < m_NumClasses; j++)
{
//计算每个类别的条件概率P(X|Y)
temp=Math.max(1e-75, Math.pow(m_Distributions[attIndex][j].getProbability(instance.value(attribute)),
m_Instances.attribute(attIndex).weight()));
probs[j]*=temp;
//更新最大概率值
if(probs[j]>max)
{
max=probs[j];
}
if(Double.isNaN(probs[j]))
{
throw new Exception(
"Nan returned from estimator for atrribute "+
attribute.name()+":\n"+
m_Distributions[attIndex][j].toString());
}
}
if(max>0 && max<1e-75)
{
//防止概率下溢的危险
for(int j=0;j<m_NumClasses;j++)
{
probs[j]*=1e75;
}
}
}
attIndex++;
}
//概率规范化
Utils.normalize(probs);
return probs;
}