贝叶斯分类器-基础知识

贝叶斯分类器-基础知识

分类: 机器学习   37人阅读  评论(0)  收藏  举报

本文转载自:http://www.cnblogs.com/phinecos/archive/2008/10/21/1315948.html,作者:phinecos(洞庭散人)

=======================================================================================

Preface

       本文缘起于最近在读的一本书-- Tom M.Mitchell《机器学习》,书中第6章详细讲解了贝叶斯学习的理论知识,为了将其应用到实际中来,参考了网上许多资料,从而得此文。文章将分为两个部分,第一部分将介绍贝叶斯学习的相关理论(如果你对理论不感兴趣,请直接跳至第二部分<<基于朴素贝叶斯分类器的文本分类算法(下)>>)。第二部分讲如何将贝叶斯分类器应用到中文文本分类,随文附上示例代码。

 Introduction

我们在《概率论和数理统计》这门课的第一章都学过贝叶斯公式和全概率公式,先来简单复习下:

条件概率

定义 A, B是两个事件,且P(A)>0 P(BA)=P(AB)/P(A)为在条件A下发生的条件事件B发生的条件概率。

乘法公式 P(A)>0 则有P(AB)=P(BA)P(A)

全概率公式和贝叶斯公式

定义 S为试验E的样本空间,B1, B2, …BnE的一组事件,若BiBj=Ф, i≠j, i, j=1, 2, …,n; B1B2Bn=S则称B1, B2, …, Bn为样本空间的一个划分。

定理 设试验E的样本空间为,AE的事件,B1, B2, …,Bn为的一个划分,且P(Bi)>0 (i=1, 2, …n),则P(A)=P(AB1)P(B1)+P(AB2)+ …+P(ABn)P(Bn)称为全概率公式。

定理 设试验俄E的样本空间为SAE的事件,B1, B2, …,Bn为的一个划分,则

P(BiA)=P(ABi)P(Bi)/∑P(BAj)P(Aj)=P(BAi)P(Ai)/P(B)

称为贝叶斯公式。说明:ij均为下标,求和均是1n  

 下面我再举个简单的例子来说明下。

示例1

考虑一个医疗诊断问题,有两种可能的假设:(1)病人有癌症。(2)病人无癌症。样本数据来自某化验测试,它也有两种可能的结果:阳性和阴性。假设我们已经有先验知识:在所有人口中只有0.008的人患病。此外,化验测试对有病的患者有98%的可能返回阳性结果,对无病患者有97%的可能返回阴性结果。

上面的数据可以用以下概率式子表示:

P(cancer)=0.008,P(cancer)=0.992

P(阳性|cancer)=0.98,P(阴性|cancer)=0.02

P(阳性|cancer)=0.03P(阴性|cancer)=0.97

假设现在有一个新病人,化验测试返回阳性,是否将病人断定为有癌症呢?我们可以来计算极大后验假设:

P(阳性|cancer)p(cancer)=0.98*0.008 = 0.0078

P(阳性|cancer)*p(cancer)=0.03*0.992 = 0.0298

因此,应该判断为无癌症。

 贝叶斯学习理论

       贝叶斯是一种基于概率的学习算法,能够用来计算显式的假设概率,它基于假设的先验概率,给定假设下观察到不同数据的概率以及观察到的数据本身(后面我们可以看到,其实就这么三点东西,呵呵)。

      我们用P(h)表示没有训练样本数据前假设h拥有的初始概率,也就称为h的先验概率,它反映了我们所拥有的关于h是一个正确假设的机会的背景知识。当然如果没有这个先验知识的话,在实际处理中,我们可以简单地将每一种假设都赋给一个相同的概率。类似,P(D)代表将要观察的训练样本数据D的先验概率(也就是说,在没有确定某一个假设成立时D的概率)。然后是P(D/h),它表示假设h成立时观察到数据D的概率。在机器学习中,我们感兴趣的是P(h/D),也就是给定了一个训练样本数据D,判断假设h成立的概率,这也称之为后验概率,它反映了在看到训练样本数据D后假设h成立的置信度。(注:后验概率p(h/D)反映了训练数据D的影响,而先验概率p(h)是独立于D的)。

 

P(h|D) = P(D|h)P(h)/p(D),从贝叶斯公式可以看出,后验概率p(h/D)取决于P(D|h)P(h)这个乘积,呵呵,这就是贝叶斯分类算法的核心思想。我们要做的就是要考虑候选假设集合H,并在其中寻找当给定训练数据D时可能性最大的假设hh属于H)。

      简单点说,就是给定了一个训练样本数据(样本数据已经人工分类好了),我们应该如何从这个样本数据集去学习,从而当我们碰到新的数据时,可以将新数据分类到某一个类别中去。那可以看到,上面的贝叶斯理论和这个任务是吻合的。

朴素贝叶斯分类

 

也许你觉得这理论还不是很懂,那我再举个简单的例子,让大家对这个算法的原理有个快速的认识。(注:这个示例摘抄自《机器学习》这本书的第三章的表3-2.

假设给定了如下训练样本数据,我们学习的目标是根据给定的天气状况判断你对PlayTennis这个请求的回答是Yes还是No

Day

Outlook

Temperature

Humidity

Wind

PlayTennis

D1

Sunny

Hot

High

Weak

No

D2

Sunny

Hot

High

Strong

No

D3

Overcast

Hot

High

Weak

Yes

D4

Rain

Mild

High

Weak

Yes

D5

Rain

Cool

Normal

Weak

Yes

D6

Rain

Cool

Normal

Strong

No

D7

Overcast

Cool

Normal

Strong

Yes

D8

Sunny

Mild

High

Weak

No

D9

Sunny

Cool

Normal

Weak

Yes

D10

Rain

Mild

Normal

Weak

Yes

D11

Sunny

Mild

Normal

Strong

Yes

D12

Overcast

Mild

High

Strong

Yes

D13

Overcast

Hot

Normal

Weak

Yes

D14

Rain

Mild

High

Strong

No

 可以看到这里样本数据集提供了14个训练样本,我们将使用此表的数据,并结合朴素贝叶斯分类器来分类下面的新实例:

(Outlook = sunny,Temprature = cool,Humidity = high,Wind = strong)

我们的任务就是对此新实例预测目标概念PlayTennis的目标值(yesno).

由上面的公式可以得到:

可以得到:

      P(PlayTennis =yes) = 9/14 = 0.64,P(PlayTennis=no)=5/14 = 0.36

      P(Wind=Stong| PlayTennis =yes)=3/9=0.33,p(Wind=Stong| PlayTennis =no)=3/5 = 0.6

其他数据类似可得,代入后得到:

P(yes)P(Sunny|yes)P(Cool|yes)P(high|yes)P(Strong|yes) = 0.0053

P(no)P(Sunny|no)P(Cool|no)P(high|no)P(Strong|no)=0.0206

因此应该分类到no这一类中。

 

贝叶斯文本分类算法

      好了,现在开始进入本文的主旨部分:如何将贝叶斯分类器应用到中文文本的分类上来?

根据联合概率公式(全概率公式)

 

  

M——训练文本集合中经过踢出无用词去除文本预处理之后关键字的数量。

作者:洞庭散人

出处:http://phinecos.cnblogs.com/    

本博客遵从 Creative Commons Attribution 3.0 License,若用于非商业目的,您可以自由转载,但请保留原作者信息和文章链接URL。

 

贝叶斯分类器--原理流程应用

分类: 机器学习   75人阅读  评论(0)  收藏  举报

目录(?)[+]

本文转载自:http://www.cnblogs.com/leoo2sk/archive/2010/09/17/naive-bayesian-classifier.html,感谢原作者张洋。

==============================================================================

算法杂货铺——分类算法之朴素贝叶斯分类(Naive Bayesian classification)

0、写在前面的话

      我个人一直很喜欢算法一类的东西,在我看来算法是人类智慧的精华,其中蕴含着无与伦比的美感。而每次将学过的算法应用到实际中,并解决了实际问题后,那种快感更是我在其它地方体会不到的。

      一直想写关于算法的博文,也曾写过零散的两篇,但也许是相比于工程性文章来说太小众,并没有引起大家的兴趣。最近面临毕业找工作,为了能给自己增加筹码,决定再次复习算法方面的知识,我决定趁这个机会,写一系列关于算法的文章。这样做,主要是为了加强自己复习的效果,我想,如果能将复习的东西用自己的理解写成文章,势必比单纯的读书做题掌握的更牢固,也更能触发自己的思考。如果能有感兴趣的朋友从中有所收获,那自然更好。

      这个系列我将其命名为“算法杂货铺”,其原因就是这些文章一大特征就是“杂”,我不会专门讨论堆栈、链表、二叉树、查找、排序等任何一本数据结构教科书都会讲的基础内容,我会从一个“专题”出发,如概率算法、分类算法、NP问题、遗传算法等,然后做一个引申,可能会涉及到算法与数据结构、离散数学、概率论、统计学、运筹学、数据挖掘、形式语言与自动机等诸多方面,因此其内容结构就像一个杂货铺。当然,我会竭尽所能,尽量使内容“杂而不乱”。

1.1、摘要

      贝叶斯分类是一类分类算法的总称,这类算法均以贝叶斯定理为基础,故统称为贝叶斯分类。本文作为分类算法的第一篇,将首先介绍分类问题,对分类问题进行一个正式的定义。然后,介绍贝叶斯分类算法的基础——贝叶斯定理。最后,通过实例讨论贝叶斯分类中最简单的一种:朴素贝叶斯分类。

1.2、分类问题综述

      对于分类问题,其实谁都不会陌生,说我们每个人每天都在执行分类操作一点都不夸张,只是我们没有意识到罢了。例如,当你看到一个陌生人,你的脑子下意识判断TA是男是女;你可能经常会走在路上对身旁的朋友说“这个人一看就很有钱、那边有个非主流”之类的话,其实这就是一种分类操作。

      从数学角度来说,分类问题可做如下定义:

      已知集合:,确定映射规则,使得任意有且仅有一个使得成立。(不考虑模糊数学里的模糊集情况)

      其中C叫做类别集合,其中每一个元素是一个类别,而I叫做项集合,其中每一个元素是一个待分类项,f叫做分类器。分类算法的任务就是构造分类器f。

      这里要着重强调,分类问题往往采用经验性方法构造映射规则,即一般情况下的分类问题缺少足够的信息来构造100%正确的映射规则,而是通过对经验数据的学习从而实现一定概率意义上正确的分类,因此所训练出的分类器并不是一定能将每个待分类项准确映射到其分类,分类器的质量与分类器构造方法、待分类数据的特性以及训练样本数量等诸多因素有关。

      例如,医生对病人进行诊断就是一个典型的分类过程,任何一个医生都无法直接看到病人的病情,只能观察病人表现出的症状和各种化验检测数据来推断病情,这时医生就好比一个分类器,而这个医生诊断的准确率,与他当初受到的教育方式(构造方法)、病人的症状是否突出(待分类数据的特性)以及医生的经验多少(训练样本数量)都有密切关系。

1.3、贝叶斯分类的基础——贝叶斯定理

      每次提到贝叶斯定理,我心中的崇敬之情都油然而生,倒不是因为这个定理多高深,而是因为它特别有用。这个定理解决了现实生活里经常遇到的问题:已知某条件概率,如何得到两个事件交换后的概率,也就是在已知P(A|B)的情况下如何求得P(B|A)。这里先解释什么是条件概率:

      表示事件B已经发生的前提下,事件A发生的概率,叫做事件B发生下事件A的条件概率。其基本求解公式为:

      贝叶斯定理之所以有用,是因为我们在生活中经常遇到这种情况:我们可以很容易直接得出P(A|B),P(B|A)则很难直接得出,但我们更关心P(B|A),贝叶斯定理就为我们打通从P(A|B)获得P(B|A)的道路。

      下面不加证明地直接给出贝叶斯定理:

      

1.4、朴素贝叶斯分类

1.4.1、朴素贝叶斯分类的原理与流程

      朴素贝叶斯分类是一种十分简单的分类算法,叫它朴素贝叶斯分类是因为这种方法的思想的很朴素,朴素贝叶斯的思想基础是这样的:对于给出的待分类项,求解在此项出现的条件下各个类别出现的概率,哪个最大,就认为此待分类项属于哪个类别。通俗来说,就好比这么个道理,你在街上看到一个黑人,我问你你猜这哥们哪里来的,你十有八九猜非洲。为什么呢?因为黑人中非洲人的比率最高,当然人家也可能是美洲人或亚洲人,但在没有其它可用信息下,我们会选择条件概率最大的类别,这就是朴素贝叶斯的思想基础。

      朴素贝叶斯分类的正式定义如下:

      1、设为一个待分类项,而每个a为x的一个特征属性。

      2、有类别集合

      3、计算

      4、如果,则

      那么现在的关键就是如何计算第3步中的各个条件概率。我们可以这么做:

      1、找到一个已知分类的待分类项集合,这个集合叫做训练样本集。

      2、统计得到在各类别下各个特征属性的条件概率估计。即

      3、如果各个特征属性是条件独立的,则根据贝叶斯定理有如下推导:

      

      因为分母对于所有类别为常数,因为我们只要将分子最大化皆可。又因为各特征属性是条件独立的,所以有:

      

      根据上述分析,朴素贝叶斯分类的流程可以由下图表示(暂时不考虑验证):

      可以看到,整个朴素贝叶斯分类分为三个阶段:

      第一阶段——准备工作阶段,这个阶段的任务是为朴素贝叶斯分类做必要的准备,主要工作是根据具体情况确定特征属性,并对每个特征属性进行适当划分,然后由人工对一部分待分类项进行分类,形成训练样本集合。这一阶段的输入是所有待分类数据,输出是特征属性和训练样本。这一阶段是整个朴素贝叶斯分类中唯一需要人工完成的阶段,其质量对整个过程将有重要影响,分类器的质量很大程度上由特征属性、特征属性划分及训练样本质量决定。

      第二阶段——分类器训练阶段,这个阶段的任务就是生成分类器,主要工作是计算每个类别在训练样本中的出现频率及每个特征属性划分对每个类别的条件概率估计,并将结果记录。其输入是特征属性和训练样本,输出是分类器。这一阶段是机械性阶段,根据前面讨论的公式可以由程序自动计算完成。

      第三阶段——应用阶段。这个阶段的任务是使用分类器对待分类项进行分类,其输入是分类器和待分类项,输出是待分类项与类别的映射关系。这一阶段也是机械性阶段,由程序完成。

1.4.2、估计类别下特征属性划分的条件概率及Laplace校准

      这一节讨论P(a|y)的估计。

      由上文看出,计算各个划分的条件概率P(a|y)是朴素贝叶斯分类的关键性步骤,当特征属性为离散值时,只要很方便的统计训练样本中各个划分在每个类别中出现的频率即可用来估计P(a|y),下面重点讨论特征属性是连续值的情况。

      当特征属性为连续值时,通常假定其值服从高斯分布(也称正态分布)。即:

      

      而

      因此只要计算出训练样本中各个类别中此特征项划分的各均值和标准差,代入上述公式即可得到需要的估计值。均值与标准差的计算在此不再赘述。

      另一个需要讨论的问题就是当P(a|y)=0怎么办,当某个类别下某个特征项划分没有出现时,就是产生这种现象,这会令分类器质量大大降低。为了解决这个问题,我们引入Laplace校准,它的思想非常简单,就是对没类别下所有划分的计数加1,这样如果训练样本集数量充分大时,并不会对结果产生影响,并且解决了上述频率为0的尴尬局面。

1.4.3、朴素贝叶斯分类实例:检测SNS社区中不真实账号

      下面讨论一个使用朴素贝叶斯分类解决实际问题的例子,为了简单起见,对例子中的数据做了适当的简化。

      这个问题是这样的,对于SNS社区来说,不真实账号(使用虚假身份或用户的小号)是一个普遍存在的问题,作为SNS社区的运营商,希望可以检测出这些不真实账号,从而在一些运营分析报告中避免这些账号的干扰,亦可以加强对SNS社区的了解与监管。

      如果通过纯人工检测,需要耗费大量的人力,效率也十分低下,如能引入自动检测机制,必将大大提升工作效率。这个问题说白了,就是要将社区中所有账号在真实账号和不真实账号两个类别上进行分类,下面我们一步一步实现这个过程。

      首先设C=0表示真实账号,C=1表示不真实账号。

      1、确定特征属性及划分

      这一步要找出可以帮助我们区分真实账号与不真实账号的特征属性,在实际应用中,特征属性的数量是很多的,划分也会比较细致,但这里为了简单起见,我们用少量的特征属性以及较粗的划分,并对数据做了修改。

      我们选择三个特征属性:a1:日志数量/注册天数,a2:好友数量/注册天数,a3:是否使用真实头像。在SNS社区中这三项都是可以直接从数据库里得到或计算出来的。

      下面给出划分:a1:{a<=0.05, 0.05<a<0.2, a>=0.2},a1:{a<=0.1, 0.1<a<0.8, a>=0.8},a3:{a=0(不是),a=1(是)}。

      2、获取训练样本

      这里使用运维人员曾经人工检测过的1万个账号作为训练样本。

      3、计算训练样本中每个类别的频率

      用训练样本中真实账号和不真实账号数量分别除以一万,得到:

      

      

      4、计算每个类别条件下各个特征属性划分的频率

      

      

      

      

      

      

      

      

      

      

      

      

      

      

      

      

      5、使用分类器进行鉴别

      下面我们使用上面训练得到的分类器鉴别一个账号,这个账号使用非真实头像,日志数量与注册天数的比率为0.1,好友数与注册天数的比率为0.2。

      

      

      可以看到,虽然这个用户没有使用真实头像,但是通过分类器的鉴别,更倾向于将此账号归入真实账号类别。这个例子也展示了当特征属性充分多时,朴素贝叶斯分类对个别属性的抗干扰性。

1.5、分类器的评价

      虽然后续还会提到其它分类算法,不过这里我想先提一下如何评价分类器的质量。

      首先要定义,分类器的正确率指分类器正确分类的项目占所有被分类项目的比率。

      通常使用回归测试来评估分类器的准确率,最简单的方法是用构造完成的分类器对训练数据进行分类,然后根据结果给出正确率评估。但这不是一个好方法,因为使用训练数据作为检测数据有可能因为过分拟合而导致结果过于乐观,所以一种更好的方法是在构造初期将训练数据一分为二,用一部分构造分类器,然后用另一部分检测分类器的准确率。

Creative Commons License

本文基于署名-非商业性使用 3.0许可协议发布,欢迎转载,演绎,但是必须保留本文的署名张洋(包含链接),且不得用于商业目的。如您有任何疑问或者授权方面的协商,请与我联系

 

贝叶斯分类器--概念

分类: 机器学习   65人阅读  评论(0)  收藏  举报

本文转载自:http://blog.csdn.net/caiye917015406/article/details/7884293,谢谢原作者!

============================================================================

这几天在学习贝叶斯分类,据说它的文本分析很给力,主要是应用简单,所以就小试以下。。。。

首先看一下贝叶斯应用的一个小例子:

一个士兵射击,分别在100,200,300处射击击的概率是0.7,0.2,0.1,而在各处射中目标的概率是0.6,0.2,0.04。现在目标已被击毁,求士兵在200米击中的概率?

这个要用到贝叶斯,设A1,A2,A3分别为士兵在100,200,300处射击,B为击中目标。

   则P(A1)=0.7,P(A2)=0.2,P(A3)=0.1。P(B|A1)=0.6,P(B|A2)=0.2,P(B|A3)=0.04。由贝叶斯公式可知

           P(A2|B)=(P(A2)*P(B|A2))/(P(A1)*P(B|A1)+P(A2)*P(B|A2)+P(A3)*P(B|A3))=(0.2*0.2)/(0.7*0.6+0.2*0.2+0.1*0.04)=0.08;

以上是贝叶斯的一个小应用,下面就详细的学习贝叶斯(本人是菜鸟,文中大部分内容均是借鉴,如有不对,大家指出)

一贝叶斯公式

   由以上我们已经可以看出贝叶斯公式,这里给出更一般的公式:

对于各式的解释,可以见例题,应该就没问题了。

二贝叶斯分类

     如果把样本属于某个类别作为条件,样本的特征向量取值作为结果,则模式识别的分类决策过程也可以看作是一种根据结果推测条件的推理过程。它可以分为两种类型:
     一确定性分类决策:
      特征空间由决策边界划分为多个决策区域,当样本属于某类时,其特征向量一定落入对应的决策区域中,当样本不属于某类时,其特征向量一定不会落入对应的决策区域中;现有待识别的样本特征向量落入了某决策区域中,则它一定属于对应的类。

       二随机性分类决策:
       特征空间中有多个类,当样本属于某类时,其特征向量会以一定的概率取得不同的值;现有待识别的样本特征向量取了某值,则它按不同概率有可能属于不同的类,分类决策将它按概率的大小划归到某一类别中。

     对于随机性分类决策,可以利用贝叶斯公式来计算样本属于各类的后验概率:

   

三贝叶斯分类器

   1最小错误率贝叶斯分类器

     当已知类别出现的先验概率P(Wi)和每个类别在样本中的概率为P(x|Wi)时,已经求的后验概率P(Wi|x).对于如此,利用最小错误率贝叶斯分类器的原理,可以做出以下判段:

       两类问题时,当P(Wi|x)>P(Wj|x)时,判决属于类别Wi.

       对于多类情况,当P(Wi|x)为所有中最大的,则属于Wi。

用图表可以很清晰的看出其分界:

二最大似然比贝叶斯分类器

   

三最小风险贝叶斯分类器

   在最小错误率贝叶斯分类器分类时,仅考虑了样本属于每一类的后验概率最初分类决策,而没有考虑每一种分类决策的风险。例如针对某项检测指标进行癌症的诊断,如果计算出患者癌症和未患癌症的后验率均为50%,如果患者真实情况患了癌症,此时做出未患的诊断则会延误时机,比做出患癌症的诊断带来更为严重的后果。

  于是,在这种情况下,要做改进。因此,在获得样本属于每一类的后验概率后,需要综合考虑各种分类决策的多带来的风险,选择分类风险最小的决策,这就是最小风险贝叶斯分类器。

 

这以上是贝叶斯的一般概念,对于贝叶斯分类器的构造还需要对参数进行估计,(未完待续)

 

贝叶斯分类器--文本分类的C语言实现

分类: 机器学习   86人阅读  评论(0)  收藏  举报

本文转载自:http://blog.csdn.net/caiye917015406/article/details/7887221,谢谢原作者!

==============================================================================

第一个是用c语言做的关于文本的分类,主要是对待分类文本所有单词在模板中概率的后验计算。算法比较简单,从网上下的(没记下地址,若不愿意公开,请留言,自当处理),稍作了一点修改。。,等有时间可以实现垃圾邮件的分类,利用斯坦福机器学习公开课中方法,统计高频词,利用朴素贝叶斯。等有时间和大家分享。

[cpp]  view plain copy
  1. #include <stdio.h>  
  2. #include <string.h>  
  3. #include <direct.h> //_getcwd(), _chdir()  
  4. #include <stdlib.h> //_MAX_PATH, system()  
  5. #include <io.h> //_finddata_t, _findfirst(), _findnext(), _findclose()  
  6. #include<iostream>  
  7. using namespace std;  
  8. //#include<fstream>  
  9. char vocabulary[1000][20];/*声明公有二维数组,用来存储分割好的单词*/  
  10.   
  11.   
  12. /*=================将要分类的文本分割成单词存储在二维数组vocabulary中================*/  
  13. //@输入参数:要分类的文本  
  14. //@输出参数:该文本中总单词数  
  15.   
  16. int SplitToWord(char text[])  
  17. {  
  18. int i=0;  
  19. char seps[]=", .\n"/*定义单词的分隔符*/   
  20. char *substring;   
  21.   
  22. /******利用分隔符将文本内容分割成单词并存储******/  
  23. substring=strtok(text,seps);   
  24. while(substring!=NULL)   
  25. {     
  26.    strcpy(vocabulary[i],substring);//将单词存储到vocabulary数组中   
  27.    substring=strtok(NULL,seps);   
  28.    i++;  
  29. }  
  30. return i; //返回一共多少个单词  
  31. }  
  32.   
  33.   
  34. /*===============================计算该目录下的文件数================================*/  
  35. //@输入参数:无  
  36. //@输出参数:该目录下.txt文件数  
  37.   
  38. int CountDirectory()  
  39. {  
  40. int count=0; //txt文件计数器  
  41. long hFile;  
  42.     _finddata_t fileinfo;  
  43.   
  44. /********查找.txt文件,记录文件数**********/  
  45.     if ((hFile=_findfirst("*.txt",&fileinfo))!=-1L)  
  46.     {  
  47.         do  
  48.         {              
  49.     count++;  
  50.         } while (_findnext(hFile,&fileinfo) == 0);  
  51. }  
  52. return count;  
  53. }  
  54.   
  55.   
  56. /*===================================计算某类别中∏P(ai|vj)===================================*/  
  57. //@输入参数:分类文本中单词数  
  58. //@输出参数:该类别下∏P(ai|vj)  
  59.   
  60. float CalculateWordProbability(int wordCount)  
  61. {  
  62. int countSame; //分类文本中的某单词在所有训练样本中出现次数  
  63. int countAll=0; //训练样本中总单词数  
  64. char token;  
  65. FILE *fp;  
  66. float wordProbability=1; //为后面联乘做准备  
  67. int i,j;  
  68. long hFile;  
  69.     _finddata_t fileinfo;  
  70.   
  71.   
  72. for(j=0;j<wordCount;j++) //对于分类样本中的每一个单词  
  73. {  
  74.    countSame=0;  
  75.    countAll=0;  
  76.    if((hFile=_findfirst("*.txt",&fileinfo))!=-1L) //对于该类别下每一个.txt文本  
  77.    {  
  78.     do  
  79.     {  
  80.      if((fp=fopen(fileinfo.name,"r"))==NULL) //是否能打开该文本  
  81.      {  
  82.       printf("Sorry!Cannot open the file!\n");  
  83.       exit(0);  
  84.      }  
  85.   
  86.      /********存储此.txt文件中每个单词并与分类文本的单词作比较*******/  
  87.      while((token = fgetc(fp)) != EOF)   
  88.      {  
  89.       char keyword[1024];   
  90.       i = 0;   
  91.        
  92.       keyword[0] = token; // 将每个词第一个字符赋给数组第一个元素  
  93.       while ((keyword[++i] = fgetc(fp)) != ' ' && keyword[i] != '\t' && keyword[i] != EOF && keyword[i] != '\n'); // 开始读字符,直到遇到空白符,说明找到一个词   
  94.       keyword[i] = '\0';// 加结束符  
  95.       countAll++;  
  96.   
  97.       if (strcmp(keyword,vocabulary[j]) == 0) //比较两个单词是否相同  
  98.        countSame++;  
  99.      }  
  100.      fclose(fp);  
  101.   
  102.     }while (_findnext(hFile,&fileinfo) == 0);   
  103.    }  
  104.    wordProbability*=(float)(countSame+1)/(float)(wordCount+countAll)*300; //计算∏P(wj|vi),为了扩大效果而*380  
  105. }  
  106.   
  107. return wordProbability;  
  108. }  
  109.     
  110.   
  111. /*============================计算每个类别的最终概率输出结果===============================*/  
  112. //@输入参数:分类文本中单词数  
  113.     
  114. void CalculateProbability(int wordCount,int num)  
  115. {  
  116. /*********将类别表存储在二维数组中*********/  
  117. FILE *fp;  
  118. char classList[10][20]; //类别列表  
  119.     char ch;    //临时读取字符使用  
  120.     int index=0; //classList的行标  
  121.     int className_c=0; //classList的列标  
  122.   
  123. if((fp=fopen("ClassList.txt","r"))==NULL)  
  124.     {  
  125.         printf("Failed to open the file: ClassList.txt.\n");  
  126.     }  
  127.     ch = fgetc(fp);  
  128.     while(ch!=EOF)  
  129.     {  
  130.         if(ch!='\n')  
  131.         {  
  132.             classList[index][className_c]=ch;  
  133.             className_c++;  
  134.         }  
  135.         else  
  136.         {  
  137.             classList[index][className_c]='\0';  
  138.             index++;  
  139.             className_c=0;  
  140.         }  
  141.    ch = fgetc(fp);  
  142. }  
  143.   
  144. /********计算总文本数和每个类别下的文本数、∏P(ai|vj)********/  
  145. int txtCount[10]; //每个类别下的训练文本数  
  146. int countAll=0; //训练集中总文本数  
  147. float wordProbability[10]; //每个类别的单词概率,即∏P(ai|vj)  
  148.   
  149. if(_chdir("D:\\openCV\\openCVProject\\openCVtext\\贝叶斯(文本分类)—c语言\\example\\1")) //更改当前绝对路径  
  150.      printf("系统找不到指定路径!\n");  
  151. else  
  152. {  
  153.    txtCount[0]=CountDirectory(); //获取该类别下.txt文件数  
  154.    countAll+=txtCount[0];  
  155.    wordProbability[0]=CalculateWordProbability(wordCount); //获取该类别下∏P(wj|vi)  
  156. }  
  157. if(_chdir("D:\\openCV\\openCVProject\\openCVtext\\贝叶斯(文本分类)—c语言\\example\\2")) //更改当前绝对路径  
  158.    printf("系统找不到指定路径!\n");  
  159. else  
  160. {  
  161.    txtCount[1]=CountDirectory(); //获取该类别下.txt文件数  
  162.    countAll+=txtCount[1];  
  163.    wordProbability[1]=CalculateWordProbability(wordCount); //获取该类别下∏P(wj|vi)  
  164. }  
  165. if(_chdir("D:\\openCV\\openCVProject\\openCVtext\\贝叶斯(文本分类)—c语言\\example\\3")) //更改当前绝对路径  
  166.      printf("系统找不到指定路径!\n");  
  167. else  
  168. {  
  169.    txtCount[2]=CountDirectory(); //获取该类别下.txt文件数  
  170.    countAll+=txtCount[2];  
  171.    wordProbability[2]=CalculateWordProbability(wordCount); //获取该类别下∏P(wj|vi)  
  172. }  
  173. if(_chdir("D:\\openCV\\openCVProject\\openCVtext\\贝叶斯(文本分类)—c语言\\example\\4")) //更改当前绝对路径  
  174.      printf("系统找不到指定路径!\n");  
  175. else  
  176. {  
  177.    txtCount[3]=CountDirectory(); //获取该类别下.txt文件数  
  178.    countAll+=txtCount[3];  
  179.    wordProbability[3]=CalculateWordProbability(wordCount); //获取该类别下∏P(wj|vi)  
  180. }  
  181. if(_chdir("D:\\openCV\\openCVProject\\openCVtext\\贝叶斯(文本分类)—c语言\\example\\5")) //更改当前绝对路径  
  182.      printf("系统找不到指定路径!\n");  
  183. else  
  184. {  
  185.    txtCount[4]=CountDirectory(); //获取该类别下.txt文件数  
  186.    countAll+=txtCount[4];  
  187.    wordProbability[4]=CalculateWordProbability(wordCount); //获取该类别下∏P(wj|vi)  
  188. }  
  189. if(_chdir("D:\\openCV\\openCVProject\\openCVtext\\贝叶斯(文本分类)—c语言\\example\\6")) //更改当前绝对路径  
  190.      printf("系统找不到指定路径!\n");  
  191. else  
  192. {  
  193.    txtCount[5]=CountDirectory(); //获取该类别下.txt文件数  
  194.    countAll+=txtCount[5];  
  195.    wordProbability[5]=CalculateWordProbability(wordCount); //获取该类别下∏P(wj|vi)  
  196. }  
  197. if(_chdir("D:\\openCV\\openCVProject\\openCVtext\\贝叶斯(文本分类)—c语言\\example\\7")) //更改当前绝对路径  
  198.      printf("系统找不到指定路径!\n");  
  199. else  
  200. {  
  201.    txtCount[6]=CountDirectory(); //获取该类别下.txt文件数  
  202.    countAll+=txtCount[6];  
  203.    wordProbability[6]=CalculateWordProbability(wordCount); //获取该类别下∏P(wj|vi)  
  204. }  
  205. if(_chdir("D:\\openCV\\openCVProject\\openCVtext\\贝叶斯(文本分类)—c语言\\example\\8")) //更改当前绝对路径  
  206.      printf("系统找不到指定路径!\n");  
  207. else  
  208. {  
  209.    txtCount[7]=CountDirectory(); //获取该类别下.txt文件数  
  210.    countAll+=txtCount[7];  
  211.    wordProbability[7]=CalculateWordProbability(wordCount); //获取该类别下∏P(wj|vi)  
  212. }  
  213. if(_chdir("D:\\openCV\\openCVProject\\openCVtext\\贝叶斯(文本分类)—c语言\\example\\9")) //更改当前绝对路径  
  214.      printf("系统找不到指定路径!\n");  
  215. else  
  216. {  
  217.    txtCount[8]=CountDirectory(); //获取该类别下.txt文件数  
  218.    countAll+=txtCount[8];  
  219.    wordProbability[8]=CalculateWordProbability(wordCount); //获取该类别下∏P(wj|vi)  
  220. }  
  221. if(_chdir("D:\\openCV\\openCVProject\\openCVtext\\贝叶斯(文本分类)—c语言\\example\\10")) //更改当前绝对路径  
  222.      printf("系统找不到指定路径!\n");  
  223. else  
  224. {  
  225.    txtCount[9]=CountDirectory(); //获取该类别下.txt文件数  
  226.    countAll+=txtCount[9];  
  227.    wordProbability[9]=CalculateWordProbability(wordCount); //获取该类别下∏P(wj|vi)  
  228. }  
  229.   
  230. /*******计算先验概率和最终概率并输出分类结果*******/  
  231. float max=0;  
  232. int classNo=0;  
  233. float priorProbability[10];  
  234. float finalProbability[10];  
  235.   
  236. for(int i=0;i<num;i++)   
  237. {  
  238.    priorProbability[i]=(float)txtCount[i]/(float)countAll; //先验概率  
  239.    finalProbability[i]=priorProbability[i]*wordProbability[i]; //最终概率  
  240.    if(finalProbability[i]>max) //找到最大概率并记录  
  241.    {  
  242.     max=finalProbability[i];  
  243.     classNo=i;  
  244.    }  
  245.    printf("该文本为类别%s的概率为:%.5e\n",classList[i],finalProbability[i]); //输出每个类别的最终概率  
  246. }  
  247. printf("\n经分析,该文本最有可能为%s类文本!\n",classList[classNo]); //输出最后分类结果  
  248. }  
  249.   
  250.   
  251. /*===================调用文本分割函数和计算最终概率函数======================*/  
  252. //@输入参数:分类文本  
  253.   
  254. void NaiveBayesClassifier(char text[],int num)  
  255. {  
  256. int vocabularyCount;//分类样本中单词数  
  257.   
  258. vocabularyCount=SplitToWord(text); //对要分类的文本进行单词分割,结果存储在vocabulary数组中,返回分类样本中单词数  
  259. CalculateProbability(vocabularyCount,num); //计算最终概率  
  260. }  
  261.   
  262.   
  263. /*===================程序入口====================*/  
  264. int main()  
  265. {  
  266.    FILE *fp;  
  267.    if((fp=fopen("text.txt","r"))==NULL)  
  268.    {  
  269.         printf("Failed to open the file: ClassList.txt.\n");  
  270.    }  
  271.    char ch = fgetc(fp);  
  272.    int i=0;  
  273.    while(ch!=EOF)  
  274.    {  
  275.        ch = fgetc(fp);  
  276.        i++;  
  277.    }  
  278.    char *text=new char(i+1);  
  279.    fseek(fp,0,SEEK_SET);//  
  280.    ch = fgetc(fp);  
  281.    int j=0;  
  282.    while(ch!=EOF)  
  283.    {  
  284.        ch = fgetc(fp);  
  285.        cout<<ch;  
  286.        text[j]=ch;  
  287.        j++;  
  288.    }  
  289.   // char text[]=new char(i);;  
  290.    int num = 2;  
  291.   
  292.    NaiveBayesClassifier(text,num); /*调用朴素贝叶斯分类函数,返回最终分类结果*/  
  293. return 1;  
  294. }  
  295.    

贝叶斯分类器--文本分类应用

分类: 机器学习   68人阅读  评论(0)  收藏  举报

本文转载自:http://www.cnblogs.com/phinecos/archive/2008/10/21/1316044.html,谢谢原作者!

源代码下载:NaviveBayesClassify.rar 

Preface

文本的分类和聚类是一个比较有意思的话题,我以前也写过一篇blog基于K-Means的文本聚类算法》,加上最近读了几本数据挖掘和机器学习的书籍,因此很想写点东西来记录下学习的所得。

在本文的上半部分《基于朴素贝叶斯分类器的文本分类算法(上)》一文中简单介绍了贝叶斯学习的基本理论,这一篇将展示如何将该理论运用到中文文本分类中来,具体的文本分类原理就不再介绍了,在上半部分有,也可以参见代码的注释。

文本特征向量

文本特征向量可以描述为文本中的字/词构成的属性。例如给出文本:

Good good study,Day day up.

可以获得该文本的特征向量集:{ Good, good, study, Day, day , up.}

朴素贝叶斯模型是文本分类模型中的一种简单但性能优越的的分类模型。为了简化计算过程,假定各待分类文本特征变量是相互独立的,即朴素贝叶斯模型的假设。相互独立表明了所有特征变量之间的表述是没有关联的。如上例中,[good][study]这两个特征变量就是没有任何关联的。

在上例中,文本是英文,但由于中文本身是没有自然分割符(如空格之类符号),所以要获得中文文本的特征变量向量首先需要对文本进行中文分词

中文分词

      这里采用极易中文分词组件,这个中文分词组件可以免费使用,提供Lucene接口,跨平台,性能可靠。

复制代码
package com.vista;
import java.io.IOException;      
import jeasy.analysis.MMAnalyzer;

/* *
* 中文分词器
*/
public   class  ChineseSpliter 
{
    
/* *
    * 对给定的文本进行中文分词
    * @param text 给定的文本
    * @param splitToken 用于分割的标记,如"|"
    * @return 分词完毕的文本
    
*/
    
public   static  String split(String text,String splitToken)
    {
        String result 
=   null ;
        MMAnalyzer analyzer 
=   new  MMAnalyzer();      
        
try       
        {
            result 
=  analyzer.segment(text, splitToken);    
        }      
        
catch  (IOException e)      
        {     
            e.printStackTrace();     
        }     
        
return  result;
    }
}
复制代码

停用词处理

      去掉文档中无意思的词语也是必须的一项工作,这里简单的定义了一些常见的停用词,并根据这些常用停用词在分词时进行判断。

复制代码
package com.vista;

/* *
* 停用词处理器
* @author phinecos 

*/
public   class  StopWordsHandler 
{
    
private   static  String stopWordsList[]  = { " " " 我们 " , " " , " 自己 " , " " , " " , " " , " " , " " , " " , " " , " " , " " , " " , " " , " " , " " , " " , " " , " " , " " , " " , " " , " " , " " , " " , " " , " " , "" }; // 常用停用词
     public   static  boolean IsStopWord(String word)
    {
        
for ( int  i = 0 ;i < stopWordsList.length; ++ i)
        {
            
if (word.equalsIgnoreCase(stopWordsList[i]))
                
return   true ;
        }
        
return   false ;
    }
}
复制代码

训练集管理器

      我们的系统首先需要从训练样本集中得到假设的先验概率和给定假设下观察到不同数据的概率。

复制代码
package  com.vista;
import  java.io.BufferedReader;
import  java.io.File;
import  java.io.FileInputStream;
import  java.io.FileNotFoundException;
import  java.io.IOException;
import  java.io.InputStreamReader;
import  java.util.Properties;
import  java.util.logging.Level;
import  java.util.logging.Logger;
/**
* 训练集管理器
*/
public   class  TrainingDataManager 
{
    
private  String[] traningFileClassifications; // 训练语料分类集合
     private  File traningTextDir; // 训练语料存放目录
     private   static  String defaultPath  =   " D:\\TrainningSet " ;
    
    
public  TrainingDataManager() 
    {
        traningTextDir 
=   new  File(defaultPath);
        
if  ( ! traningTextDir.isDirectory()) 
        {
            
throw   new  IllegalArgumentException( " 训练语料库搜索失败! [ "   + defaultPath  +   " ] " );
        }
        
this .traningFileClassifications  =  traningTextDir.list();
    }
    
/**
    * 返回训练文本类别,这个类别就是目录名
    * 
@return  训练文本类别
    
*/
    
public  String[] getTraningClassifications() 
    {
        
return   this .traningFileClassifications;
    }
    
/**
    * 根据训练文本类别返回这个类别下的所有训练文本路径(full path)
    * 
@param  classification 给定的分类
    * 
@return  给定分类下所有文件的路径(full path)
    
*/
    
public  String[] getFilesPath(String classification) 
    {
        File classDir 
=   new  File(traningTextDir.getPath()  + File.separator  + classification);
        String[] ret 
=  classDir.list();
        
for  ( int  i  =   0 ; i  <  ret.length; i ++
        {
            ret[i] 
=  traningTextDir.getPath()  + File.separator  + classification  + File.separator  + ret[i];
        }
        
return  ret;
    }
    
/**
    * 返回给定路径的文本文件内容
    * 
@param  filePath 给定的文本文件路径
    * 
@return  文本内容
    * 
@throws  java.io.FileNotFoundException
    * 
@throws  java.io.IOException
    
*/
    
public   static  String getText(String filePath)  throws  FileNotFoundException,IOException 
    {
        InputStreamReader isReader 
= new  InputStreamReader( new  FileInputStream(filePath), " GBK " );
        BufferedReader reader 
=   new  BufferedReader(isReader);
        String aline;
        StringBuilder sb 
=   new  StringBuilder();
        
while  ((aline  =  reader.readLine())  !=   null )
        {
            sb.append(aline 
+   "   " );
        }
        isReader.close();
        reader.close();
        
return  sb.toString();
    }
    
/**
    * 返回训练文本集中所有的文本数目
    * 
@return  训练文本集中所有的文本数目
    
*/
    
public   int  getTrainingFileCount()
    {
        
int  ret  =   0 ;
        
for  ( int  i  =   0 ; i  <  traningFileClassifications.length; i ++ )
        {
            ret 
+= getTrainingFileCountOfClassification(traningFileClassifications[i]);
        }
        
return  ret;
    }
    
/**
    * 返回训练文本集中在给定分类下的训练文本数目
    * 
@param  classification 给定的分类
    * 
@return  训练文本集中在给定分类下的训练文本数目
    
*/
    
public   int  getTrainingFileCountOfClassification(String classification)
    {
        File classDir 
=   new  File(traningTextDir.getPath()  + File.separator  + classification);
        
return  classDir.list().length;
    }
    
/**
    * 返回给定分类中包含关键字/词的训练文本的数目
    * 
@param  classification 给定的分类
    * 
@param  key 给定的关键字/词
    * 
@return  给定分类中包含关键字/词的训练文本的数目
    
*/
    
public   int  getCountContainKeyOfClassification(String classification,String key) 
    {
        
int  ret  =   0 ;
        
try  
        {
            String[] filePath 
=  getFilesPath(classification);
            
for  ( int  j  =   0 ; j  <  filePath.length; j ++
            {
                String text 
=  getText(filePath[j]);
                
if  (text.contains(key)) 
                {
                    ret
++ ;
                }
            }
        }
        
catch  (FileNotFoundException ex) 
        {
        Logger.getLogger(TrainingDataManager.
class .getName()).log(Level.SEVERE,  null ,ex);
    
        } 
        
catch  (IOException ex)
        {
            Logger.getLogger(TrainingDataManager.
class .getName()).log(Level.SEVERE,  null ,ex);
        }
        
return  ret;
    }
}
复制代码

先验概率

      先验概率是我们需要计算的两大概率值之一

复制代码
package  com.vista;
/**
* 先验概率计算
* <h3>先验概率计算</h3>
* P(c<sub>j</sub>)=N(C=c<sub>j</sub>)<b>/</b>N <br>
* 其中,N(C=c<sub>j</sub>)表示类别c<sub>j</sub>中的训练文本数量;
* N表示训练文本集总数量。
*/
public   class  PriorProbability 
{
    
private   static  TrainingDataManager tdm  = new  TrainingDataManager();
    
/**
    * 先验概率
    * 
@param  c 给定的分类
    * 
@return  给定条件下的先验概率
    
*/
    
public   static   float  calculatePc(String c)
    {
        
float  ret  =  0F;
        
float  Nc  =  tdm.getTrainingFileCountOfClassification(c);
        
float  N  =  tdm.getTrainingFileCount();
        ret 
=  Nc  /  N;
        
return  ret;
    }
}

复制代码

分类条件概率

      这是另一个影响因子,和先验概率一起来决定最终结果

复制代码
package  com.vista;

/**
* <b>类</b>条件概率计算
*
* <h3>类条件概率</h3>
* P(x<sub>j</sub>|c<sub>j</sub>)=( N(X=x<sub>i</sub>, C=c<sub>j
* </sub>)+1 ) <b>/</b> ( N(C=c<sub>j</sub>)+M+V ) <br>
* 其中,N(X=x<sub>i</sub>, C=c<sub>j</sub>)表示类别c<sub>j</sub>中包含属性x<sub>
* i</sub>的训练文本数量;N(C=c<sub>j</sub>)表示类别c<sub>j</sub>中的训练文本数量;M值用于避免
* N(X=x<sub>i</sub>, C=c<sub>j</sub>)过小所引发的问题;V表示类别的总数。
*
* <h3>条件概率</h3>
* <b>定义</b> 设A, B是两个事件,且P(A)>0 称<br>
* <tt>P(B∣A)=P(AB)/P(A)</tt><br>
* 为在条件A下发生的条件事件B发生的条件概率。

*/

public   class  ClassConditionalProbability 
{
    
private   static  TrainingDataManager tdm  =   new  TrainingDataManager();
    
private   static   final   float  M  =  0F;
    
    
/**
    * 计算类条件概率
    * 
@param  x 给定的文本属性
    * 
@param  c 给定的分类
    * 
@return  给定条件下的类条件概率
    
*/
    
public   static   float  calculatePxc(String x, String c) 
    {
        
float  ret  =  0F;
        
float  Nxc  =  tdm.getCountContainKeyOfClassification(c, x);
        
float  Nc  =  tdm.getTrainingFileCountOfClassification(c);
        
float  V  =  tdm.getTraningClassifications().length;
        ret 
=  (Nxc  +   1 /  (Nc  +  M  +  V);  // 为了避免出现0这样极端情况,进行加权处理
         return  ret;
    }
}
复制代码

分类结果

      用来保存各个分类及其计算出的概率值,

复制代码
package  com.vista;
/**
* 分类结果
*/
public   class  ClassifyResult 
{
    
public   double  probility; // 分类的概率
     public  String classification; // 分类
     public  ClassifyResult()
    {
        
this .probility  =   0 ;
        
this .classification  =   null ;
    }
}
复制代码

朴素贝叶斯分类器

      利用样本数据集计算先验概率和各个文本向量属性在分类中的条件概率,从而计算出各个概率值,最后对各个概率值进行排序,选出最大的概率值,即为所属的分类。

复制代码
package  com.vista;
import  com.vista.ChineseSpliter;
import  com.vista.ClassConditionalProbability;
import  com.vista.PriorProbability;
import  com.vista.TrainingDataManager;
import  com.vista.StopWordsHandler;
import  java.util.ArrayList;
import  java.util.Comparator;
import  java.util.List;
import  java.util.Vector;

/**
* 朴素贝叶斯分类器
*/
public   class  BayesClassifier 
{
    
private  TrainingDataManager tdm; // 训练集管理器
     private  String trainnigDataPath; // 训练集路径
     private   static   double  zoomFactor  =   10.0f ;
    
/**
    * 默认的构造器,初始化训练集
    
*/
    
public  BayesClassifier() 
    {
        tdm 
= new  TrainingDataManager();
    }

    
/**
    * 计算给定的文本属性向量X在给定的分类Cj中的类条件概率
    * <code>ClassConditionalProbability</code>连乘值
    * 
@param  X 给定的文本属性向量
    * 
@param  Cj 给定的类别
    * 
@return  分类条件概率连乘值,即<br>
    
*/
    
float  calcProd(String[] X, String Cj) 
    {
        
float  ret  =   1.0F ;
        
//  类条件概率连乘
         for  ( int  i  =   0 ; i  < X.length; i ++ )
        {
            String Xi 
=  X[i];
            
// 因为结果过小,因此在连乘之前放大10倍,这对最终结果并无影响,因为我们只是比较概率大小而已
            ret  *= ClassConditionalProbability.calculatePxc(Xi, Cj) * zoomFactor;
        }
        
//  再乘以先验概率
        ret  *=  PriorProbability.calculatePc(Cj);
        
return  ret;
    }
    
/**
    * 去掉停用词
    * 
@param  text 给定的文本
    * 
@return  去停用词后结果
    
*/
    
public  String[] DropStopWords(String[] oldWords)
    {
        Vector
< String >  v1  =   new  Vector < String > ();
        
for ( int  i = 0 ;i < oldWords.length; ++ i)
        {
            
if (StopWordsHandler.IsStopWord(oldWords[i]) == false )
            {
// 不是停用词
                v1.add(oldWords[i]);
            }
        }
        String[] newWords 
=   new  String[v1.size()];
        v1.toArray(newWords);
        
return  newWords;
    }
    
/**
    * 对给定的文本进行分类
    * 
@param  text 给定的文本
    * 
@return  分类结果
    
*/
    @SuppressWarnings(
" unchecked " )
    
public  String classify(String text) 
    {
        String[] terms 
=   null ;
        terms
=  ChineseSpliter.split(text,  "   " ).split( "   " ); // 中文分词处理(分词后结果可能还包含有停用词)
        terms  =  DropStopWords(terms); // 去掉停用词,以免影响分类
        
        String[] Classes 
=  tdm.getTraningClassifications(); // 分类
         float  probility  =   0.0F ;
        List
< ClassifyResult >  crs  =   new  ArrayList < ClassifyResult > (); // 分类结果
         for  ( int  i  =   0 ; i  < Classes.length; i ++
        {
            String Ci 
=  Classes[i]; // 第i个分类
            probility  =  calcProd(terms, Ci); // 计算给定的文本属性向量terms在给定的分类Ci中的分类条件概率
            
// 保存分类结果
            ClassifyResult cr  =   new  ClassifyResult();
            cr.classification 
=  Ci; // 分类
            cr.probility  =  probility; // 关键字在分类的条件概率
            System.out.println( " In process. " );
            System.out.println(Ci 
+   " "   +  probility);
            crs.add(cr);
        }
        
// 对最后概率结果进行排序
        java.util.Collections.sort(crs, new  Comparator() 
        {
            
public   int  compare( final  Object o1, final  Object o2) 
            {
                
final  ClassifyResult m1  =  (ClassifyResult) o1;
                
final  ClassifyResult m2  =  (ClassifyResult) o2;
                
final   double  ret  =  m1.probility  -  m2.probility;
                
if  (ret  <   0
                {
                    
return   1 ;
                } 
                
else  
                {
                    
return   - 1 ;
                }
            }
        });
        
// 返回概率最大的分类
         return  crs.get( 0 ).classification;
    }
    
    
public   static   void  main(String[] args)
    {
        String text 
=   " 微软公司提出以446亿美元的价格收购雅虎中国网2月1日报道 美联社消息,微软公司提出以446亿美元现金加股票的价格收购搜索网站雅虎公司。微软提出以每股31美元的价格收购雅虎。微软的收购报价较雅虎1月31日的收盘价19.18美元溢价62%。微软公司称雅虎公司的股东可以选择以现金或股票进行交易。微软和雅虎公司在2006年底和2007年初已在寻求双方合作。而近两年,雅虎一直处于困境:市场份额下滑、运营业绩不佳、股价大幅下跌。对于力图在互联网市场有所作为的微软来说,收购雅虎无疑是一条捷径,因为双方具有非常强的互补性。(小桥) " ;
        BayesClassifier classifier 
=   new  BayesClassifier(); // 构造Bayes分类器
        String result  =  classifier.classify(text); // 进行分类
        System.out.println( " 此项属于[ " + result + " ] " );
    }
}
复制代码

训练集与分类测试

作为测试,这里选用Sogou实验室的文本分类数据,我只使用了mini版本。迷你版本有10个类别 ,共计100篇文章,总大小244KB

使用的测试文本:

复制代码
微软公司提出以446亿美元的价格收购雅虎

中国网2月1日报道 美联社消息,微软公司提出以446亿美元现金加股票的价格收购搜索网站雅虎公司。

微软提出以每股31美元的价格收购雅虎。微软的收购报价较雅虎1月31日的收盘价19
. 18美元溢价62%。微软公司称雅虎公司的股东可以选择以现金或股票进行交易。

微软和雅虎公司在2006年底和2007年初已在寻求双方合作。而近两年,雅虎一直处于困境:市场份额下滑、运营业绩不佳、股价大幅下跌。对于力图在互联网市场有所作为的微软来说,收购雅虎无疑是一条捷径,因为双方具有非常强的互补性。
( 小桥 )
复制代码

使用mini版本的测试结果:

复制代码
In process .
IT:
2.8119528E-5
In process
.
体育:
2.791735E-21
In process
.
健康:
3.3188528E-12
In process
.
军事:
2.532662E-19
In process
.
招聘:
2.3753596E-17
In process
.
教育:
4.2023427E-19
In process
.
文化:
6.0595915E-23
In process
.
旅游:
5.1286412E-17
In process
.
汽车:
4.085446E-8
In process
.
财经:
3.7337095E-10
此项属于[IT]
复制代码

作者:洞庭散人

出处:http://phinecos.cnblogs.com/    

本博客遵从 Creative Commons Attribution 3.0 License,若用于非商业目的,您可以自由转载,但请保留原作者信息和文章链接URL。
 

OpenCV机器学习(1):贝叶斯分类器实现代码分析

分类: OpenCV 机器学习   214人阅读  评论(5)  收藏  举报

目录(?)[+]

OpenCV的机器学习类定义在ml.hpp文件中,基础类是CvStatModel,其他各种分类器从这里继承而来。

今天研究CvNormalBayesClassifier分类器。

1.类定义

在ml.hpp中有以下类定义:

  1. class CV_EXPORTS_W CvNormalBayesClassifier : public CvStatModel  
  2. {  
  3. public:  
  4.     CV_WRAP CvNormalBayesClassifier();  
  5.     virtual ~CvNormalBayesClassifier();  
  6.   
  7.     CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,  
  8.         const CvMat* varIdx=0, const CvMat* sampleIdx=0 );  
  9.   
  10.     virtual bool train( const CvMat* trainData, const CvMat* responses,  
  11.         const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );  
  12.   
  13.     virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0 ) const;  
  14.     CV_WRAP virtual void clear();  
  15.   
  16.     CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,  
  17.                             const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );  
  18.     CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,  
  19.                        const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),  
  20.                        bool update=false );  
  21.     CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const;  
  22.   
  23.     virtual void write( CvFileStorage* storage, const char* name ) const;  
  24.     virtual void read( CvFileStorage* storage, CvFileNode* node );  
  25.   
  26. protected:  
  27.     int     var_count, var_all;  
  28.     CvMat*  var_idx;  
  29.     CvMat*  cls_labels;  
  30.     CvMat** count;  
  31.     CvMat** sum;  
  32.     CvMat** productsum;  
  33.     CvMat** avg;  
  34.     CvMat** inv_eigen_values;  
  35.     CvMat** cov_rotate_mats;  
  36.     CvMat*  c;  
  37. };  

2.示例

此类使用方法如下:(引用别人的代码,忘记出处了,非常抱歉这个。。。)

  1. //openCV中贝叶斯分类器的API函数用法举例  
  2. //运行环境:win7 + VS2005 + openCV2.4.5  
  3.   
  4. #include "global_include.h"  
  5.   
  6. using namespace std;  
  7. using namespace cv;  
  8.   
  9. //10个样本特征向量维数为12的训练样本集,第一列为该样本的类别标签  
  10. double inputArr[10][13] =   
  11. {  
  12.      1,0.708333,1,1,-0.320755,-0.105023,-1,1,-0.419847,-1,-0.225806,0,1,   
  13.     -1,0.583333,-1,0.333333,-0.603774,1,-1,1,0.358779,-1,-0.483871,0,-1,  
  14.      1,0.166667,1,-0.333333,-0.433962,-0.383562,-1,-1,0.0687023,-1,-0.903226,-1,-1,  
  15.     -1,0.458333,1,1,-0.358491,-0.374429,-1,-1,-0.480916,1,-0.935484,0,-0.333333,  
  16.     -1,0.875,-1,-0.333333,-0.509434,-0.347032,-1,1,-0.236641,1,-0.935484,-1,-0.333333,  
  17.     -1,0.5,1,1,-0.509434,-0.767123,-1,-1,0.0534351,-1,-0.870968,-1,-1,  
  18.      1,0.125,1,0.333333,-0.320755,-0.406393,1,1,0.0839695,1,-0.806452,0,-0.333333,  
  19.      1,0.25,1,1,-0.698113,-0.484018,-1,1,0.0839695,1,-0.612903,0,-0.333333,  
  20.      1,0.291667,1,1,-0.132075,-0.237443,-1,1,0.51145,-1,-0.612903,0,0.333333,  
  21.      1,0.416667,-1,1,0.0566038,0.283105,-1,1,0.267176,-1,0.290323,0,1  
  22. };  
  23.   
  24. //一个测试样本的特征向量  
  25. double testArr[]=  
  26. {  
  27.     0.25,1,1,-0.226415,-0.506849,-1,-1,0.374046,-1,-0.83871,0,-1  
  28. };  
  29.   
  30.   
  31. int _tmain(int argc, _TCHAR* argv[])  
  32. {  
  33.     Mat trainData(10, 12, CV_32FC1);//构建训练样本的特征向量  
  34.     for (int i=0; i<10; i++)  
  35.     {  
  36.         for (int j=0; j<12; j++)  
  37.         {  
  38.             trainData.at<float>(i, j) = inputArr[i][j+1];  
  39.         }  
  40.     }  
  41.   
  42.     Mat trainResponse(10, 1, CV_32FC1);//构建训练样本的类别标签  
  43.     for (int i=0; i<10; i++)  
  44.     {  
  45.         trainResponse.at<float>(i, 0) = inputArr[i][0];  
  46.     }  
  47.   
  48.     CvNormalBayesClassifier nbc;  
  49.     bool trainFlag = nbc.train(trainData, trainResponse);//进行贝叶斯分类器训练  
  50.     if (trainFlag)  
  51.     {  
  52.         cout<<"train over..."<<endl;  
  53.         nbc.save("normalBayes.txt");  
  54.     }  
  55.     else  
  56.     {  
  57.         cout<<"train error..."<<endl;  
  58.         system("pause");  
  59.         exit(-1);  
  60.     }  
  61.   
  62.   
  63.     CvNormalBayesClassifier testNbc;  
  64.     testNbc.load("normalBayes.txt");  
  65.   
  66.     Mat testSample(1, 12, CV_32FC1);//构建测试样本  
  67.     for (int i=0; i<12; i++)  
  68.     {  
  69.         testSample.at<float>(0, i) = testArr[i];  
  70.     }  
  71.   
  72.     float flag = testNbc.predict(testSample);//进行测试  
  73.     cout<<"flag = "<<flag<<endl;  
  74.   
  75.     system("pause");  
  76.     return 0;  
  77. }  

3.步骤

两步走:

1.调用train函数训练分类器;

2.调用predict函数,判定测试样本的类别。

以上示例代码还延时了怎样使用save和load函数,使得训练好的分类器可以保存在文本中。

4.初始化

接下来,看CvNormalBayesClassifier类的无参数初始化:

  1. CvNormalBayesClassifier::CvNormalBayesClassifier()  
  2. {  
  3.     var_count = var_all = 0;  
  4.     var_idx = 0;  
  5.     cls_labels = 0;  
  6.     count = 0;  
  7.     sum = 0;  
  8.     productsum = 0;  
  9.     avg = 0;  
  10.     inv_eigen_values = 0;  
  11.     cov_rotate_mats = 0;  
  12.     c = 0;  
  13.     default_model_name = "my_nb";  
  14. }  
还有另一种带参数的初始化形式:
  1. CvNormalBayesClassifier::CvNormalBayesClassifier(  
  2.     const CvMat* _train_data, const CvMat* _responses,  
  3.     const CvMat* _var_idx, const CvMat* _sample_idx )  
  4. {  
  5.     var_count = var_all = 0;  
  6.     var_idx = 0;  
  7.     cls_labels = 0;  
  8.     count = 0;  
  9.     sum = 0;  
  10.     productsum = 0;  
  11.     avg = 0;  
  12.     inv_eigen_values = 0;  
  13.     cov_rotate_mats = 0;  
  14.     c = 0;  
  15.     default_model_name = "my_nb";  
  16.   
  17.     train( _train_data, _responses, _var_idx, _sample_idx );  
  18. }  
可见,带参数形式糅合了类的初始化和train函数。

另外,以Mat参数形式的对应函数版本,功能是一致的,只不过为了体现2.0以后版本的C++特性罢了。如下:

  1. CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,  
  2.                         const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );  
  3. CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,  
  4.                    const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),  
  5.                    bool update=false );  
  6. CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const;  

5.训练

下面开始分析train函数,分析CvMat格式参数的train函数,即:

  1. bool train( const CvMat* trainData, const CvMat* responses,const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );  

在进入该函数之前,还要先回头看看CvNormalBayesClassifier类有哪些数据成员:

  1. protected:  
  2.     int     var_count, var_all; //每个样本的特征维数、即变量数目,或者说trainData的列数目(在varIdx=0时)  
  3.     CvMat*  var_idx;        //特征子集的索引,可能特征数目为100,但是只用其中一部分训练  
  4.     CvMat*  cls_labels;     //类别数目  
  5.     CvMat** count;      //count[0...(classNum-1)],每个元素是一个CvMat(rows=1,cols=var_count)指针,代表训练数据中每一类的某个特征的数目  
  6.     CvMat** sum;        //sum[0...(classNum-1)],每个元素是一个CvMat(rows=1,cols=var_count)指针,代表训练数据中每一类的某个特征的累加和  
  7.     CvMat** productsum;     //productsum[0...(classNum-1)],每个元素是一个CvMat(rows=cols=var_count)指针,存储类内特征相关矩阵  
  8.     CvMat** avg;        //avg[0...(classNum-1)],每个元素是一个CvMat(rows=1,cols=var_count)指针,代表训练数据中每一类的某个特征的平均值  
  9.     CvMat** inv_eigen_values;//inv_eigen_values[0...(classNum-1)],每个元素是一个CvMat(rows=1,cols=var_count)指针,代表训练数据中每一类的某个特征的特征值的倒数  
  10.     CvMat** cov_rotate_mats;    //特征变量的协方差矩阵经过SVD奇异值分解后得到的特征向量矩阵  
  11.     CvMat*  c;  

这些数据成员,怎样使用呢?在train函数中见分晓:

  1. bool CvNormalBayesClassifier::train( const CvMat* _train_data, const CvMat* _responses,  
  2.                                     const CvMat* _var_idx, const CvMat* _sample_idx, bool update )  
  3. {  
  4.     const float min_variation = FLT_EPSILON;  
  5.     bool result = false;  
  6.     CvMat* responses   = 0;  
  7.     const float** train_data = 0;  
  8.     CvMat* __cls_labels = 0;  
  9.     CvMat* __var_idx = 0;  
  10.     CvMat* cov = 0;  
  11.   
  12.     CV_FUNCNAME( "CvNormalBayesClassifier::train" );  
  13.   
  14.     __BEGIN__;  
  15.   
  16.     int cls, nsamples = 0, _var_count = 0, _var_all = 0, nclasses = 0;  
  17.     int s, c1, c2;  
  18.     const int* responses_data;  
  19.   
  20.     //1.整理训练数据  
  21.     CV_CALL( cvPrepareTrainData( 0,  
  22.         _train_data, CV_ROW_SAMPLE, _responses, CV_VAR_CATEGORICAL,  
  23.         _var_idx, _sample_idx, false, &train_data,  
  24.         &nsamples, &_var_count, &_var_all, &responses,  
  25.         &__cls_labels, &__var_idx ));  
  26.   
  27.     if( !update )   //如果是初始训练数据  
  28.     {  
  29.         const size_t mat_size = sizeof(CvMat*);  
  30.         size_t data_size;  
  31.   
  32.         clear();  
  33.   
  34.         var_idx = __var_idx;  
  35.         cls_labels = __cls_labels;  
  36.         __var_idx = __cls_labels = 0;  
  37.         var_count = _var_count;  
  38.         var_all = _var_all;  
  39.   
  40.         nclasses = cls_labels->cols;  
  41.         data_size = nclasses*6*mat_size;  
  42.   
  43.         CV_CALL( count = (CvMat**)cvAlloc( data_size ));  
  44.         memset( count, 0, data_size );          //count[cls]存储第cls类每个属性变量个数  
  45.                                         
  46.         sum             = count      + nclasses;//sum[cls]存储第cls类每个属性取值的累加和  
  47.         productsum      = sum        + nclasses;//productsum[cls]存储第cls类的协方差矩阵的乘积项sum(XiXj),cov(Xi,Xj)=sum(XiXj)-sum(Xi)E(Xj)  
  48.         avg             = productsum + nclasses;//avg[cls]存储第cls类的每个变量均值  
  49.         inv_eigen_values= avg        + nclasses;//inv_eigen_values[cls]存储第cls类的协方差矩阵的特征值  
  50.         cov_rotate_mats = inv_eigen_values         + nclasses;//存储第cls类的矩阵的特征值对应的特征向量  
  51.   
  52.         CV_CALL( c = cvCreateMat( 1, nclasses, CV_64FC1 ));  
  53.           
  54.         for( cls = 0; cls < nclasses; cls++ )    //对所有类别  
  55.         {  
  56.             CV_CALL(count[cls]            = cvCreateMat( 1, var_count, CV_32SC1 ));  
  57.             CV_CALL(sum[cls]              = cvCreateMat( 1, var_count, CV_64FC1 ));  
  58.             CV_CALL(productsum[cls]       = cvCreateMat( var_count, var_count, CV_64FC1 ));  
  59.             CV_CALL(avg[cls]              = cvCreateMat( 1, var_count, CV_64FC1 ));  
  60.             CV_CALL(inv_eigen_values[cls] = cvCreateMat( 1, var_count, CV_64FC1 ));  
  61.             CV_CALL(cov_rotate_mats[cls]  = cvCreateMat( var_count, var_count, CV_64FC1 ));  
  62.             CV_CALL(cvZero( count[cls] ));  
  63.             CV_CALL(cvZero( sum[cls] ));  
  64.             CV_CALL(cvZero( productsum[cls] ));  
  65.             CV_CALL(cvZero( avg[cls] ));  
  66.             CV_CALL(cvZero( inv_eigen_values[cls] ));  
  67.             CV_CALL(cvZero( cov_rotate_mats[cls] ));  
  68.         }  
  69.     }  
  70.     else    //如果是更新训练数据  
  71.     {  
  72.         // check that the new training data has the same dimensionality etc.  
  73.         if( _var_count != var_count || _var_all != var_all || !((!_var_idx && !var_idx) ||  
  74.             (_var_idx && var_idx && cvNorm(_var_idx,var_idx,CV_C) < DBL_EPSILON)) )  
  75.             CV_ERROR( CV_StsBadArg,  
  76.             "The new training data is inconsistent with the original training data" );  
  77.   
  78.         if( cls_labels->cols != __cls_labels->cols ||  
  79.             cvNorm(cls_labels, __cls_labels, CV_C) > DBL_EPSILON )  
  80.             CV_ERROR( CV_StsNotImplemented,  
  81.             "In the current implementation the new training data must have absolutely "  
  82.             "the same set of class labels as used in the original training data" );  
  83.   
  84.         nclasses = cls_labels->cols;  
  85.     }  
  86.   
  87.     responses_data = responses->data.i;  
  88.     CV_CALL( cov = cvCreateMat( _var_count, _var_count, CV_64FC1 ));  
  89.   
  90.     //2.处理训练数据,计算每一类的  
  91.     // process train data (count, sum , productsum)   
  92.     for( s = 0; s < nsamples; s++ )  
  93.     {  
  94.         cls = responses_data[s];  
  95.         int* count_data = count[cls]->data.i;  
  96.         double* sum_data = sum[cls]->data.db;  
  97.         double* prod_data = productsum[cls]->data.db;  
  98.         const float* train_vec = train_data[s];  
  99.   
  100.         for( c1 = 0; c1 < _var_count; c1++, prod_data += _var_count )  
  101.         {  
  102.             double val1 = train_vec[c1];  
  103.             sum_data[c1] += val1;  
  104.             count_data[c1]++;  
  105.             for( c2 = c1; c2 < _var_count; c2++ )  
  106.                 prod_data[c2] += train_vec[c2]*val1;  
  107.         }  
  108.     }  
  109.   
  110.     //计算每一类的每个属性平均值、协方差矩阵  
  111.     // calculate avg, covariance matrix, c  
  112.     for( cls = 0; cls < nclasses; cls++ )    //对每一类  
  113.     {  
  114.         double det = 1;  
  115.         int i, j;  
  116.         CvMat* w = inv_eigen_values[cls];  
  117.         int* count_data = count[cls]->data.i;  
  118.         double* avg_data = avg[cls]->data.db;  
  119.         double* sum1 = sum[cls]->data.db;  
  120.   
  121.         cvCompleteSymm( productsum[cls], 0 );  
  122.   
  123.         for( j = 0; j < _var_count; j++ )    //计算当前类别cls的每个变量属性值的平均值  
  124.         {  
  125.             int n = count_data[j];  
  126.             avg_data[j] = n ? sum1[j] / n : 0.;  
  127.         }  
  128.   
  129.         count_data = count[cls]->data.i;  
  130.         avg_data = avg[cls]->data.db;  
  131.         sum1 = sum[cls]->data.db;  
  132.   
  133.         for( i = 0; i < _var_count; i++ )//计算当前类别cls的变量协方差矩阵,矩阵大小为_var_count * _var_count,注意协方差矩阵对称。  
  134.         {  
  135.             double* avg2_data = avg[cls]->data.db;  
  136.             double* sum2 = sum[cls]->data.db;  
  137.             double* prod_data = productsum[cls]->data.db + i*_var_count;  
  138.             double* cov_data = cov->data.db + i*_var_count;  
  139.             double s1val = sum1[i];  
  140.             double avg1 = avg_data[i];  
  141.             int _count = count_data[i];  
  142.   
  143.             for( j = 0; j <= i; j++ )  
  144.             {  
  145.                 double avg2 = avg2_data[j];  
  146.                 double cov_val = prod_data[j] - avg1 * sum2[j] - avg2 * s1val + avg1 * avg2 * _count;  
  147.                 cov_val = (_count > 1) ? cov_val / (_count - 1) : cov_val;  
  148.                 cov_data[j] = cov_val;  
  149.             }  
  150.         }  
  151.   
  152.         CV_CALL( cvCompleteSymm( cov, 1 ));  
  153.         CV_CALL( cvSVD( cov, w, cov_rotate_mats[cls], 0, CV_SVD_U_T ));  
  154.         CV_CALL( cvMaxS( w, min_variation, w ));  
  155.         for( j = 0; j < _var_count; j++ )  
  156.             det *= w->data.db[j];  
  157.   
  158.         CV_CALL( cvDiv( NULL, w, w ));  
  159.         c->data.db[cls] = det > 0 ? log(det) : -700;  
  160.     }  
  161.   
  162.     result = true;  
  163.   
  164.     __END__;  
  165.   
  166.     if( !result || cvGetErrStatus() < 0 )  
  167.         clear();  
  168.   
  169.     cvReleaseMat( &cov );  
  170.     cvReleaseMat( &__cls_labels );  
  171.     cvReleaseMat( &__var_idx );  
  172.     cvFree( &train_data );  
  173.   
  174.     return result;  
  175. }  
训练部分就此完成。

6.预测

下面看用于预测的predict函数的实现代码:

  1. float CvNormalBayesClassifier::predict( const CvMat* samples, CvMat* results ) const  
  2. {  
  3.     float value = 0;  
  4.   
  5.     if( !CV_IS_MAT(samples) || CV_MAT_TYPE(samples->type) != CV_32FC1 || samples->cols != var_all )  
  6.         CV_Error( CV_StsBadArg,  
  7.         "The input samples must be 32f matrix with the number of columns = var_all" );  
  8.   
  9.     if( samples->rows > 1 && !results )  
  10.         CV_Error( CV_StsNullPtr,  
  11.         "When the number of input samples is >1, the output vector of results must be passed" );  
  12.   
  13.     if( results )  
  14.     {  
  15.         if( !CV_IS_MAT(results) || (CV_MAT_TYPE(results->type) != CV_32FC1 &&  
  16.         CV_MAT_TYPE(results->type) != CV_32SC1) ||  
  17.         (results->cols != 1 && results->rows != 1) ||  
  18.         results->cols + results->rows - 1 != samples->rows )  
  19.         CV_Error( CV_StsBadArg, "The output array must be integer or floating-point vector "  
  20.         "with the number of elements = number of rows in the input matrix" );  
  21.     }  
  22.   
  23.     const int* vidx = var_idx ? var_idx->data.i : 0;  
  24.   
  25.     cv::parallel_for(cv::BlockedRange(0, samples->rows), predict_body(c, cov_rotate_mats, inv_eigen_values, avg, samples,  
  26.                                                                       vidx, cls_labels, results, &value, var_count  
  27.     ));  
  28.   
  29.     return value;  
  30. }  
可以发现,预测部分核心代码是:
  1. cv::parallel_for(cv::BlockedRange(0, samples->rows), predict_body(c, cov_rotate_mats, inv_eigen_values, avg, samples,  
  2.                                                                       vidx, cls_labels, results, &value, var_count));  
parallel_for是用于并行支持的,可能会调用tbb模块。predict_body则是一个结构体,内部的()符号被重载,实现预测功能。其完整定义如下:

  1. //predict函数调用predict_body结构体的()符号重载函数,实现基于贝叶斯的分类  
  2. struct predict_body   
  3. {  
  4.     predict_body(CvMat* _c, CvMat** _cov_rotate_mats, CvMat** _inv_eigen_values, CvMat** _avg,  
  5.                 const CvMat* _samples, const int* _vidx, CvMat* _cls_labels,  
  6.                 CvMat* _results, float* _value, int _var_count1)  
  7.     {  
  8.         c = _c;  
  9.         cov_rotate_mats = _cov_rotate_mats;  
  10.         inv_eigen_values = _inv_eigen_values;  
  11.         avg = _avg;  
  12.         samples = _samples;  
  13.         vidx = _vidx;  
  14.         cls_labels = _cls_labels;  
  15.         results = _results;  
  16.         value = _value;  
  17.         var_count1 = _var_count1;  
  18.     }  
  19.   
  20.     CvMat* c;  
  21.     CvMat** cov_rotate_mats;  
  22.     CvMat** inv_eigen_values;  
  23.     CvMat** avg;  
  24.     const CvMat* samples;  
  25.     const int* vidx;  
  26.     CvMat* cls_labels;  
  27.   
  28.     CvMat* results;  
  29.     float* value;  
  30.     int var_count1;  
  31.   
  32.     void operator()( const cv::BlockedRange& range ) const  
  33.     {  
  34.   
  35.         int cls = -1;  
  36.         int rtype = 0, rstep = 0;  
  37.         int nclasses = cls_labels->cols;  
  38.         int _var_count = avg[0]->cols;  
  39.   
  40.         if (results)  
  41.         {  
  42.             rtype = CV_MAT_TYPE(results->type);  
  43.             rstep = CV_IS_MAT_CONT(results->type) ? 1 : results->step/CV_ELEM_SIZE(rtype);  
  44.         }  
  45.         // allocate memory and initializing headers for calculating  
  46.         cv::AutoBuffer<double> buffer(nclasses + var_count1);  
  47.         CvMat diff = cvMat( 1, var_count1, CV_64FC1, &buffer[0] );  
  48.   
  49.         for(int k = range.begin(); k < range.end(); k += 1 )//对于每个输入测试样本  
  50.         {  
  51.             int ival;  
  52.             double opt = FLT_MAX;  
  53.   
  54.             for(int i = 0; i < nclasses; i++ )   //对于每一类别,计算其似然概率  
  55.             {  
  56.   
  57.                 double cur = c->data.db[i];  
  58.                 CvMat* u = cov_rotate_mats[i];  
  59.                 CvMat* w = inv_eigen_values[i];  
  60.   
  61.                 const double* avg_data = avg[i]->data.db;  
  62.                 const float* x = (const float*)(samples->data.ptr + samples->step*k);  
  63.   
  64.                 // cov = u w u'  -->  cov^(-1) = u w^(-1) u'  
  65.                 for(int j = 0; j < _var_count; j++ ) //计算特征相对于均值的偏移  
  66.                     diff.data.db[j] = avg_data[j] - x[vidx ? vidx[j] : j];  
  67.   
  68.                 cvGEMM( &diff, u, 1, 0, 0, &diff, CV_GEMM_B_T );  
  69.                 for(int j = 0; j < _var_count; j++ )//计算特征的联合概率  
  70.                 {  
  71.                     double d = diff.data.db[j];  
  72.                     cur += d*d*w->data.db[j];  
  73.                 }  
  74.   
  75.                 if( cur < opt )  //找到分类概率最大的  
  76.                 {  
  77.                     cls = i;  
  78.                     opt = cur;  
  79.                 }  
  80.                 // probability = exp( -0.5 * cur )   
  81.   
  82.             }//for(int i = 0; i < nclasses; i++ )  
  83.   
  84.             ival = cls_labels->data.i[cls];  
  85.             if( results )  
  86.             {  
  87.                 if( rtype == CV_32SC1 )  
  88.                     results->data.i[k*rstep] = ival;  
  89.                 else  
  90.                     results->data.fl[k*rstep] = (float)ival;  
  91.             }  
  92.             if( k == 0 )  
  93.                 *value = (float)ival;  
  94.   
  95.         }//for(int k = range.begin()...  
  96.   
  97.     }//void operator()...  
  98. };  
好啦,预测部分至此完成。

但有一个小小疑问:好像在predict部分实现代码中没有看到先验概率参与到计算当中,而贝叶斯估计是应该p(w|x)=p(w)*p(x|w)/...的呀,但是这里只看到了计算p(x|w)的部分。没有p(w)的身影,不知道为何,盼高人指点。

贝叶斯代码分析完成。


评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值