Weka Classifiers: NaiveBayes

Three methods are central to implementing a classifier in Weka:

buildClassifier(Instances instances)

double[] distributionForInstance(Instance instance)

double classifyInstance(Instance instance) throws Exception

Only one of the last two needs to be implemented, so our analysis will focus on these methods.
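The reason one of the two prediction methods suffices is that each can be derived from the other: given a class distribution, classification is just picking the index of the largest probability. The sketch below illustrates that relationship in plain Java (a simplified stand-in using a double[] instead of a Weka Instance; Weka's own base class wires this up internally):

```java
// Illustrative sketch: classifyInstance can be derived from
// distributionForInstance by taking the argmax of the class probabilities.
// This is a simplified stand-in, not Weka's actual AbstractClassifier code.
public class ArgmaxSketch {
    static double classifyFromDistribution(double[] dist) {
        int best = 0;
        for (int i = 1; i < dist.length; i++) {
            if (dist[i] > dist[best]) best = i; // pick the most probable class
        }
        return best;
    }

    public static void main(String[] args) {
        // class 1 has the highest probability
        System.out.println(classifyFromDistribution(new double[] {0.1, 0.7, 0.2})); // 1.0
    }
}
```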

 

Naive Bayes is a classification method based on Bayes' theorem together with the assumption of conditional independence between features. Put simply, a naive Bayes classifier assumes that each feature of a sample is independent of every other feature. The two most important quantities in Bayesian classification are the class prior probability P(Ci) and the class-conditional probability P(x|Ci).

 

1. Key fields

NaiveBayes defines two fields:

protected Estimator[][] m_Distributions; — holds the estimators used to compute P(x|Ci);

protected Estimator m_ClassDistribution; — holds P(Ci).

 

2. Training the classifier

Let us look at the training process, which lives mainly in buildClassifier(Instances instances). This method trains the class-conditional probabilities, or rather the parameters needed to compute them. The key code is shown below (only the important parts are analyzed):

public void buildClassifier(Instances instances) throws Exception {

    ……

        for (int j = 0; j < m_Instances.numClasses(); j++) {
            // create an Estimator for each class
            switch (attribute.type()) {
            case Attribute.NUMERIC:
                if (m_UseKernelEstimator) {
                    m_Distributions[attIndex][j] = new KernelEstimator(numPrecision);
                } else {
                    m_Distributions[attIndex][j] = new NormalEstimator(numPrecision);
                }
                break;
            case Attribute.NOMINAL:
                m_Distributions[attIndex][j] = new DiscreteEstimator(attribute.numValues(), true);
                break;
            default:
                throw new Exception("Attribute type unknown to NaiveBayes");
            }
        }
        attIndex++;
    }

    // Compute counts
    // scan the training samples here to train the parameters
    Enumeration enumInsts = m_Instances.enumerateInstances();
    while (enumInsts.hasMoreElements()) {
        Instance instance = (Instance) enumInsts.nextElement();
        updateClassifier(instance);
    }

    // Save space
    m_Instances = new Instances(m_Instances, 0);
}

 

As the code above shows, an array of type Estimator is first set up to hold the class-conditional probabilities P(xi|C); depending on the estimation algorithm, each element of the array is an instance of one of the subclasses KernelEstimator, NormalEstimator, or DiscreteEstimator.

updateClassifier is then called on every training sample to update the parameters in m_Distributions[attIndex][j] (the probability/parameters of attribute attIndex under class j).

 

2.1 Probability estimation algorithms

The class-conditional probability is estimated differently depending on the attribute type:

When the attribute is nominal (discrete), the class-conditional probability P(Xi = xi | C = c) can be estimated as the fraction of training instances of class c whose attribute value equals xi. The code is as follows:

public void addValue(double data, double weight) {

    m_Counts[(int) data] += weight;
    m_SumOfCounts += weight;
}

Here data is the attribute's value, m_Counts records how many times each value has occurred, and m_SumOfCounts records the total count over all values of the attribute. The weight is 1 if the instance has no weight set, otherwise the configured value; the same applies to weight everywhere below.
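Note that the constructor call earlier, new DiscreteEstimator(attribute.numValues(), true), passes true to enable Laplace smoothing: every count starts at 1, so a value never seen in training still gets a small non-zero probability. The following self-contained sketch shows this mechanism (illustrative names, not Weka's actual class):

```java
// Minimal sketch of a discrete (nominal) estimator in the spirit of Weka's
// DiscreteEstimator; names and details are illustrative, not Weka's code.
public class DiscreteEstimatorSketch {
    private final double[] counts;
    private double sumOfCounts;

    // laplace: start every count at 1 so unseen values never get probability 0
    public DiscreteEstimatorSketch(int numValues, boolean laplace) {
        counts = new double[numValues];
        if (laplace) {
            for (int i = 0; i < numValues; i++) counts[i] = 1;
            sumOfCounts = numValues;
        }
    }

    public void addValue(double data, double weight) {
        counts[(int) data] += weight;
        sumOfCounts += weight;
    }

    public double getProbability(double data) {
        if (sumOfCounts == 0) return 0;
        return counts[(int) data] / sumOfCounts;
    }

    public static void main(String[] args) {
        DiscreteEstimatorSketch est = new DiscreteEstimatorSketch(2, true);
        est.addValue(0, 1);
        est.addValue(0, 1);
        est.addValue(1, 1);
        // with Laplace smoothing: P(0) = (1 + 2) / (2 + 3) = 0.6
        System.out.println(est.getProbability(0));
    }
}
```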

When the attribute is numeric (continuous):

One approach is to discretize the data, but because it is hard to choose the right discretization scheme and granularity, this approach is less common and is not analyzed in detail here.

The other approach is to assume the continuous variable follows some probability distribution and then estimate the distribution's parameters from the training samples. Two algorithms are used here to estimate the parameters.

a. Gaussian (normal) density estimation

A Gaussian distribution has two parameters, the mean μ and the standard deviation σ. Both are estimated from the training samples separately for each class, giving the class-conditional probability of Xi as

P(X_i = x_i \mid C = c) = \frac{1}{\sqrt{2\pi}\,\sigma_c} \exp\!\left(-\frac{(x_i - \mu_c)^2}{2\sigma_c^2}\right)
The main code is as follows:

public void addValue(double data, double weight) {

    if (weight == 0) {
        return;
    }
    data = round(data);
    m_SumOfWeights += weight;
    m_SumOfValues += data * weight;
    m_SumOfValuesSq += data * data * weight;

    if (m_SumOfWeights > 0) {
        m_Mean = m_SumOfValues / m_SumOfWeights; // sample mean
        double stdDev = Math.sqrt(Math.abs(m_SumOfValuesSq - m_Mean * m_SumOfValues)
                / m_SumOfWeights); // sample standard deviation
        if (stdDev > 1e-10) {
            m_StandardDev = Math.max(m_Precision / (2 * 3),
                    // allow at most 3 sd's within one interval
                    stdDev);
        }
    }
}

The standard deviation is computed as

\sigma = \sqrt{\frac{\sum_i w_i x_i^2 - \mu \sum_i w_i x_i}{\sum_i w_i}} = \sqrt{\frac{\sum_i w_i x_i^2}{\sum_i w_i} - \mu^2}

The results are stored in m_Mean and m_StandardDev and are used later when classifying new samples.
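The update above maintains only three running sums, from which the weighted mean and (population) standard deviation can be recovered at any time. A self-contained sketch of the same bookkeeping (illustrative, not Weka's NormalEstimator itself):

```java
// Sketch of the incremental weighted mean / standard deviation update used by
// the Gaussian estimator above; a sketch, not Weka's NormalEstimator.
public class NormalEstimatorSketch {
    private double sumOfWeights, sumOfValues, sumOfValuesSq;
    double mean, stdDev;

    public void addValue(double data, double weight) {
        if (weight == 0) return;
        sumOfWeights += weight;
        sumOfValues += data * weight;
        sumOfValuesSq += data * data * weight;
        mean = sumOfValues / sumOfWeights;
        // population variance E[x^2] - mean^2, written as in the Weka code
        stdDev = Math.sqrt(Math.abs(sumOfValuesSq - mean * sumOfValues) / sumOfWeights);
    }

    public static void main(String[] args) {
        NormalEstimatorSketch est = new NormalEstimatorSketch();
        for (double x : new double[] {2, 4, 4, 4, 5, 5, 7, 9}) est.addValue(x, 1);
        System.out.println(est.mean);   // 5.0
        System.out.println(est.stdDev); // 2.0
    }
}
```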

 

b. Kernel density estimation

Although Gaussian density estimation is the usual way to handle continuous attributes, not all data follow a Gaussian distribution, and many researchers have proposed nonparametric density estimation methods. The method used here is Gaussian kernel density estimation (other kernel functions could be used as well). It looks similar to Gaussian density estimation, except that it is the average of a series of kernel estimates:

P(x \mid c) = \frac{1}{n} \sum_{i=1}^{n} g(x;\, \mu_i, \sigma_c)

Here i ranges over all observed values of the attribute within class c, with μ_i = x_i and σ_c = range / √n, where range = (maximum attribute value − minimum attribute value) and n is the number of attribute values observed in class c.

Because the algorithm has to keep every value of the attribute seen in class c, it has a higher space complexity and consumes more memory.

For details on this algorithm, see the paper "Estimating Continuous Distributions in Bayesian Classifiers".

 

The main code is as follows:

public void addValue(double data, double weight) {
    if (weight == 0) {
        return;
    }
    data = round(data);
    int insertIndex = findNearestValue(data);
    if ((m_NumValues <= insertIndex) || (m_Values[insertIndex] != data)) {
        // store the observed value
        if (m_NumValues < m_Values.length) {
            int left = m_NumValues - insertIndex;
            System.arraycopy(m_Values, insertIndex, m_Values, insertIndex + 1, left);
            System.arraycopy(m_Weights, insertIndex, m_Weights, insertIndex + 1, left);

            m_Values[insertIndex] = data;
            m_Weights[insertIndex] = weight;
            m_NumValues++;
        } else {
            // arrays are full: double their capacity
            double[] newValues = new double[m_Values.length * 2];
            double[] newWeights = new double[m_Values.length * 2];
            int left = m_NumValues - insertIndex;
            System.arraycopy(m_Values, 0, newValues, 0, insertIndex);
            System.arraycopy(m_Weights, 0, newWeights, 0, insertIndex);
            newValues[insertIndex] = data;
            newWeights[insertIndex] = weight;
            System.arraycopy(m_Values, insertIndex, newValues, insertIndex + 1, left);
            System.arraycopy(m_Weights, insertIndex, newWeights, insertIndex + 1, left);
            m_NumValues++;
            m_Values = newValues;
            m_Weights = newWeights;
        }
        if (weight != 1) {
            m_AllWeightsOne = false;
        }
    } else {
        m_Weights[insertIndex] += weight;
        m_AllWeightsOne = false;
    }
    m_SumOfWeights += weight;
    double range = m_Values[m_NumValues - 1] - m_Values[0]; // compute the range
    if (range > 0) {
        m_StandardDev = Math.max(range / Math.sqrt(m_SumOfWeights),
                // allow at most 3 sds within one interval
                m_Precision / (2 * 3)); // kernel standard deviation
    }
}

m_Values stores all the values of the attribute, kept in ascending order.

Once all instances have been scanned, parameter training is complete and new instances can be classified.

 

3. Predicting the class of a sample

This part exploits the naive assumption of naive Bayes, namely that the attributes are mutually independent given the class, to simplify the computation:

P(C = c \mid X = x) = \frac{P(C = c) \prod_i P(X_i = x_i \mid C = c)}{P(X = x)}

Since P(X) is the same for every class, the class prediction rule is:

\hat{c} = \arg\max_c \; P(C = c) \prod_i P(X_i = x_i \mid C = c)

Let us walk through the code:

public double[] distributionForInstance(Instance instance) throws Exception {

    if (m_UseDiscretization) {
        m_Disc.input(instance);
        instance = m_Disc.output();
    }
    double[] probs = new double[m_NumClasses];
    for (int j = 0; j < m_NumClasses; j++) {
        probs[j] = m_ClassDistribution.getProbability(j); // the prior P(Ci)
    }
    Enumeration enumAtts = instance.enumerateAttributes();
    int attIndex = 0;
    // scan the attributes
    while (enumAtts.hasMoreElements()) {
        Attribute attribute = (Attribute) enumAtts.nextElement();
        if (!instance.isMissing(attribute)) {
            double temp, max = 0;
            for (int j = 0; j < m_NumClasses; j++) {
                temp = Math.max(1e-75, Math.pow(
                        m_Distributions[attIndex][j].getProbability(instance.value(attribute)),
                        m_Instances.attribute(attIndex).weight())); // class-conditional probability P(xi|C)
                probs[j] *= temp;
                if (probs[j] > max) {
                    max = probs[j];
                }
            }
            if ((max > 0) && (max < 1e-75)) { // rescale to avoid numeric underflow
                for (int j = 0; j < m_NumClasses; j++) {
                    probs[j] *= 1e75;
                }
            }
        }
        attIndex++;
    }

    // Display probabilities
    Utils.normalize(probs);
    return probs;
}

To compute P(xi|C):

For a nominal attribute we simply compute the fraction of occurrences of the given value (m_Counts[(int) data] / m_SumOfCounts):

public double getProbability(double data) {

    if (m_SumOfCounts == 0) {
        return 0;
    }
    return (double) m_Counts[(int) data] / m_SumOfCounts;
}

For numeric attributes:

The probability is computed in the method getProbability(double data), analyzed below.

For the Gaussian estimator, the data is first standardized, and the probability mass of an interval of width m_Precision around the value is computed from the normal CDF:

P(X = x \mid c) \approx \Phi\!\left(\frac{x - \mu + p/2}{\sigma}\right) - \Phi\!\left(\frac{x - \mu - p/2}{\sigma}\right)

where Φ is the standard normal CDF and p is m_Precision.
public double getProbability(double data) {

    data = round(data);
    double zLower = (data - m_Mean - (m_Precision / 2)) / m_StandardDev; // standardize the data
    double zUpper = (data - m_Mean + (m_Precision / 2)) / m_StandardDev;

    double pLower = Statistics.normalProbability(zLower); // standard normal CDF
    double pUpper = Statistics.normalProbability(zUpper);
    return pUpper - pLower;
}
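The same interval-probability computation can be reproduced without Weka by building the normal CDF from a classic erf approximation. The sketch below stands in for Statistics.normalProbability; the erf coefficients are the well-known Abramowitz & Stegun 7.1.26 approximation (max error about 1.5e-7):

```java
// Sketch of the Gaussian interval probability computed above; normalCdf
// stands in for Weka's Statistics.normalProbability.
public class NormalInterval {
    // erf approximation (Abramowitz & Stegun 7.1.26)
    static double erf(double x) {
        double sign = x < 0 ? -1 : 1;
        x = Math.abs(x);
        double t = 1.0 / (1.0 + 0.3275911 * x);
        double y = 1 - ((((1.061405429 * t - 1.453152027) * t + 1.421413741) * t
                - 0.284496736) * t + 0.254829592) * t * Math.exp(-x * x);
        return sign * y;
    }

    static double normalCdf(double z) {
        return 0.5 * (1 + erf(z / Math.sqrt(2)));
    }

    // probability mass of the interval [x - p/2, x + p/2] under N(mean, sd)
    static double getProbability(double x, double mean, double sd, double precision) {
        double zLower = (x - mean - precision / 2) / sd;
        double zUpper = (x - mean + precision / 2) / sd;
        return normalCdf(zUpper) - normalCdf(zLower);
    }

    public static void main(String[] args) {
        // width-0.1 interval at the mean of N(0, 1): about 0.1 * pdf(0) = 0.0399
        System.out.println(getProbability(0, 0, 1, 0.1));
    }
}
```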

 

For the kernel density estimator:

Recall the probability formula, here written in the interval form that the code actually computes:

P(X = x \mid c) = \frac{1}{\sum_i w_i} \sum_i w_i \left[ \Phi\!\left(\frac{x_i - x + p/2}{\sigma}\right) - \Phi\!\left(\frac{x_i - x - p/2}{\sigma}\right) \right]
public double getProbability(double data) {

    double delta = 0, sum = 0, currentProb = 0;
    double zLower = 0, zUpper = 0;
    if (m_NumValues == 0) { // no values observed for this attribute
        zLower = (data - (m_Precision / 2)) / m_StandardDev;
        zUpper = (data + (m_Precision / 2)) / m_StandardDev;
        return (Statistics.normalProbability(zUpper) - Statistics.normalProbability(zLower));
    }
    double weightSum = 0;
    int start = findNearestValue(data);
    // scan the stored values, starting from the one nearest to data
    for (int i = start; i < m_NumValues; i++) {
        delta = m_Values[i] - data;
        zLower = (delta - (m_Precision / 2)) / m_StandardDev;
        zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
        currentProb = (Statistics.normalProbability(zUpper) - Statistics.normalProbability(zLower)); // Gaussian kernel probability
        sum += currentProb * m_Weights[i]; // accumulate the weighted sum
        weightSum += m_Weights[i];
        if (currentProb * (m_SumOfWeights - weightSum) < sum * MAX_ERROR) {
            break; // the remaining kernels cannot change the result noticeably
        }
    }
    for (int i = start - 1; i >= 0; i--) {
        delta = m_Values[i] - data;
        zLower = (delta - (m_Precision / 2)) / m_StandardDev;
        zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
        currentProb = (Statistics.normalProbability(zUpper) - Statistics.normalProbability(zLower));
        sum += currentProb * m_Weights[i];
        weightSum += m_Weights[i];
        if (currentProb * (m_SumOfWeights - weightSum) < sum * MAX_ERROR) {
            break;
        }
    }
    return sum / m_SumOfWeights; // weighted average over all kernels
}
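The core idea, stripped of the interval form and the early-termination scan, is that the density at x is a weighted average of Gaussian bumps centered on the stored values. A minimal sketch in pdf form (illustrative, not Weka's KernelEstimator):

```java
// Minimal Gaussian kernel density sketch: each stored value contributes one
// Gaussian bump; the density is their weighted average. Illustrative only.
public class KdeSketch {
    static double gaussianPdf(double x, double mean, double sd) {
        double z = (x - mean) / sd;
        return Math.exp(-0.5 * z * z) / (sd * Math.sqrt(2 * Math.PI));
    }

    // density at x: weighted average of kernels centered on the observed values
    static double density(double x, double[] values, double[] weights, double sd) {
        double sum = 0, weightSum = 0;
        for (int i = 0; i < values.length; i++) {
            sum += weights[i] * gaussianPdf(x, values[i], sd);
            weightSum += weights[i];
        }
        return sum / weightSum;
    }

    public static void main(String[] args) {
        double[] values = {1, 2, 8, 9};
        double[] weights = {1, 1, 1, 1};
        // bandwidth as in the addValue code: range / sqrt(n) = 8 / 2 = 4
        double sd = (9 - 1) / Math.sqrt(4);
        System.out.println(density(1.5, values, weights, sd));
        System.out.println(density(5.0, values, weights, sd));
    }
}
```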

 

After each class-conditional probability is computed, it is multiplied into the running product:

probs[j] *= temp

Finally probs is returned to the caller; the class corresponding to its largest element is the predicted class.
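Putting the whole prediction step together: start from the priors, multiply in one class-conditional probability per attribute, then normalize so the entries sum to 1. The numbers below are made up for illustration:

```java
// End-to-end sketch of the prediction step: prior times per-attribute
// class-conditional probabilities, then normalization. Illustrative numbers.
public class PredictSketch {
    static double[] distribution(double[] priors, double[][] condProbs) {
        double[] probs = priors.clone();
        for (double[] attr : condProbs) {          // one row per attribute: P(x_i | c_j)
            for (int j = 0; j < probs.length; j++) {
                probs[j] *= attr[j];
            }
        }
        double sum = 0;
        for (double p : probs) sum += p;           // normalize so the entries sum to 1
        for (int j = 0; j < probs.length; j++) probs[j] /= sum;
        return probs;
    }

    public static void main(String[] args) {
        double[] priors = {0.5, 0.5};
        double[][] cond = {{0.9, 0.2},             // attribute 1 strongly favors class 0
                           {0.6, 0.4}};            // attribute 2 mildly favors class 0
        double[] probs = distribution(priors, cond);
        // class 0: 0.5*0.9*0.6 = 0.27, class 1: 0.5*0.2*0.4 = 0.04;
        // normalized, class 0 wins with probability 0.27/0.31 (about 0.871)
        System.out.println(probs[0] + " " + probs[1]);
    }
}
```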
