Naive Bayes 朴素贝叶斯的JAVA代码实现

JAVA工程文件可在http://download.csdn.net/detail/u011321908/6385295下载 下面贴的代码仅是主类程序


1.关于贝叶斯分类

bayes 是一种统计学分类方法,它基于贝叶斯定理,它假定一个属性值对给定类的影响独立于其它属性点的值。该假定称作类条件独立。做次假定是为了简化所需计算,并在此意义下称为“朴素的”。

bayes分类的算法大致如下:

(1)对于属性值是离散的,并且目标label值也是离散的情况下。分别计算label不同取值的概率,以及样本在label情况下的概率值,然后将这些概率值相乘最后得到一个概率的乘积,选择概率乘积最大的那个值对应的label值就为预测的结果。

例如以下:是预测苹果在给定属性的情况是甜还是不甜的情况:

color={0,1,2,3} weight={2,3,4};是属性序列,为离散型。sweet={yes,no}是目标值,也为离散型;

这时我们要预测在color=3,weight=3的情况下的目标值,计算过程如下:

P{y=yes}=2/5=0.4; P{color=3|yes}=1/2=0.5;P{weight=3|yes}=1/2=0.5;   故F{color=3,weight=3}取yesd的概率为 0.4*0.5*0.5=0.1;

P{y=no}=3/5=0.6; P{color=3|no}=1/3 P{weight=3|no}=1/3;  故P{color=3,weight=3}取no为 0.6*1/3*1/3=1/15;

0.1>1/15 所以认为 F{color=3,weight=3}=yes;

(2)对于属性值是连续的情况,思想和离散是相同的,只是这时候我们计算属性的概率用的是高斯密度:


这里的Xk就是样本的取值,u是样本所在列的均值,kesi是标准差;

最后代码如下:

/*
 * To change this template, choose Tools | Templates
 * and open the template in the editor.
 */
package auxiliary;

import java.util.ArrayList;

/**
 *
 * @author Michael Kong
 */
public class NaiveBayes extends Classifier {

	boolean isClassfication[];
   ArrayList 
   
   
    
    lblClass=new ArrayList
    
    
     
     ();  //存储目标值的种类
   ArrayList
     
     
      
      lblCount=new ArrayList
      
      
       
       ();//存储目标值的个数
   ArrayList
       
       
         lblProba=new ArrayList 
        
          ();//存储对应的label的概率 CountProbility countlblPro; /*@ClassListBasedLabel是将训练数组按照 label的顺序来分类存储*/ ArrayList 
          
           
           
             >> ClassListBasedLabel=new ArrayList 
             
              
              
                >> (); public NaiveBayes() { } @Override /** * @train主要完成求一些概率 * 1.labels中的不同取值的概率f(Yi); 对应28,29行两段代码 * 2.将训练数组按目标值分类存储 第37行代码 * */ public void train(boolean[] isCategory, double[][] features, double[] labels){ isClassfication=isCategory; countlblPro=new CountProbility(isCategory,features,labels); countlblPro.getlblClass(lblClass, lblCount, lblProba); ArrayList 
                
                
                  > trainingList=countlblPro.UnionFeaLbl(features, labels); //union the features[][] and labels[] ClassListBasedLabel=countlblPro.getClassListBasedLabel(lblClass, trainingList); } @Override /**3.在Y的条件下,计算Xi的概率 f(Xi/Y); * 4.返回使得Yi*Xi*...概率最大的那个label的取值 * */ public double predict(double[] features) { int max_index; //用于记录使概率取得最大的那个索引 int index=0; //这个索引是 标识不同的labels 所对应的概率 ArrayList 
                 
                   pro_=new ArrayList 
                  
                    (); //这个概率数组是存储features[] 在不同labels下对应的概率 for(ArrayList 
                    
                    
                      > elements: ClassListBasedLabel) //依次取不同的label值对应的元祖集合 { ArrayList 
                     
                       pro=new ArrayList 
                      
                        ();//存同一个label对应的所有概率,之后其中的元素自乘 double probility=1.0; //计算概率的乘积 for(int i=0;i 
                       
                         element:elements) //依次取labels中的所有元祖 { if(element.get(i).equals(features[i])) //如果这个元祖的第index数据和b相等,那么就count就加1 count++; } if(count==0) { pro.add(1/(double)(elements.size()+1)); } else pro.add(count/(double)elements.size()); //统计完所有之后 计算概率值 并加入 } else { double Sdev; double Mean; double probi=1.0; Mean=countlblPro.getMean(elements, i); Sdev=countlblPro.getSdev(elements, i); if(Sdev!=0) { probi*=((1/(Math.sqrt(2*Math.PI)*Sdev))*(Math.exp(-(features[i]-Mean)*(features[i]-Mean)/(2*Sdev*Sdev)))); pro.add(probi); } else pro.add(1.5); } } for(double pi:pro) probility*=pi; //将所有概率相乘 probility*=lblProba.get(index);//最后再乘以一个 Yi pro_.add(probility);// 放入pro_ 至此 一个循环结束, index++; } double max_pro=pro_.get(0); max_index=0; for(int i=1;i 
                        
                          =max_pro) { max_pro=pro_.get(i); max_index=i; } } return lblClass.get(max_index); } public class CountProbility { boolean []isCatory; double[][]features; private double[]labels; public CountProbility(boolean[] isCategory, double[][] features, double[] labels) { this.isCatory=isCategory; this.features=features; this.labels=labels; } //获取label中取值情况 public void getlblClass( ArrayList 
                         
                           lblClass,ArrayList 
                          
                            lblCount,ArrayList 
                           
                             lblProba) { int j=0; for(double i:labels) { //如果当前的label不存在于lblClass则加入 if(!lblClass.contains(i)) { lblClass.add(j,i); lblCount.add(j++,1); } else //如果label中已经存在,就将其计数加1 { int index=lblClass.indexOf(i); int count=lblCount.get(index); lblCount.set(index,++count); } } for(int i=0;i 
                            
                              > UnionFeaLbl(double[][] features, double[] labels) { ArrayList 
                              
                              
                                >traingList=new ArrayList 
                                
                                
                                  >(); for(int i=0;i 
                                 
                                   elements=new ArrayList 
                                  
                                    (); for(int j=0;j 
                                   
                                     >> getClassListBasedLabel (ArrayList 
                                    
                                      lblClass,ArrayList 
                                      
                                      
                                        >trainingList) { ArrayList 
                                        
                                         
                                         
                                           >> ClassListBasedLabel=new ArrayList 
                                           
                                            
                                            
                                              >> () ; for(double num:lblClass) { ArrayList 
                                              
                                              
                                                > elements=new ArrayList 
                                                
                                                
                                                  >(); for(ArrayList 
                                                 
                                                   element:trainingList) { if(element.get(element.size()-1).equals(num)) elements.add(element); } ClassListBasedLabel.add(elements); } return ClassListBasedLabel; } public double getMean(ArrayList 
                                                   
                                                   
                                                     > elements,int index) { double sum=0.0; double Mean; for(ArrayList 
                                                    
                                                      element:elements) { sum+=element.get(index); } Mean=sum/(double)elements.size(); return Mean; } public double getSdev(ArrayList 
                                                      
                                                      
                                                        > elements,int index) { double dev=0.0; double Mean; Mean=getMean(elements,index); for(ArrayList 
                                                       
                                                         element:elements) { dev+=Math.pow((element.get(index)-Mean),2); } dev=Math.sqrt(dev/elements.size()); return dev; } } } 
                                                        
                                                       
                                                      
                                                     
                                                    
                                                   
                                                  
                                                 
                                                
                                               
                                              
                                             
                                            
                                           
                                          
                                         
                                        
                                       
                                      
                                     
                                    
                                   
                                  
                                 
                                
                               
                              
                             
                            
                           
                          
                         
                        
                       
                      
                     
                    
                   
                  
                 
                
               
              
             
            
           
          
         
       
      
      
     
     
    
    
   
   




评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值