naive bayes java_Naive Bayes 朴素贝叶斯的JAVA代码实现

1.关于贝叶斯分类算法

bayes 是一种统计学分类方法,它基于贝叶斯定理,它假定一个属性值对给定类的影响独立于其它属性点的值。该假定称做类条件独立。作次假定是为了简化所需计算,并在此意义下称为“朴素的”。数组

bayes分类的算法大体以下:ide

(1)对于属性值是离散的,而且目标label值也是离散的状况下。分别计算label不一样取值的几率,以及样本在label状况下的几率值,而后将这些几率值相乘最后获得一个几率的乘积,选择几率乘积最大的那个值对应的label值就为预测的结果。this

例如如下:是预测苹果在给定属性的状况是甜仍是不甜的状况:spa

3a2d33fe470eeb04b6cde06fb5cea030.png

color={0,1,2,3} weight={2,3,4};是属性序列,为离散型。sweet={yes,no}是目标值,也为离散型;.net

这时咱们要预测在color=3,weight=3的状况下的目标值,计算过程以下:code

P{y=yes}=2/5=0.4;P{color=3|yes}=1/2=0.5;P{weight=3|yes}=1/2=0.5;   故F{color=3,weight=3}取yesd的几率为 0.4*0.5*0.5=0.1;blog

P{y=no}=3/5=0.6;P{color=3|no}=1/3 P{weight=3|no}=1/3;  故P{color=3,weight=3}取no为 0.6*1/3*1/3=1/15;索引

0.1>1/15 因此认为 F{color=3,weight=3}=yes;

(2)对于属性值是连续的状况,思想和离散是相同的,只是这时候咱们计算属性的几率用的是高斯密度:

209550363d53d8b135ebbc998ec04438.png

这里的Xk就是样本的取值,u是样本所在列的均值,kesi是标准差;

最后代码以下:

/*

* To change this template, choose Tools | Templates

* and open the template in the editor.

*/

package auxiliary;

import java.util.ArrayList;

/**

*

* @author Michael Kong

*/

public class NaiveBayes extends Classifier {

boolean isClassfication[];

ArrayList lblClass=new ArrayList(); //存储目标值的种类

ArrayListlblCount=new ArrayList();//存储目标值的个数

ArrayListlblProba=new ArrayList();//存储对应的label的几率

CountProbility countlblPro;

/*@ClassListBasedLabel是将训练数组按照 label的顺序来分类存储*/

ArrayList>> ClassListBasedLabel=new ArrayList>> ();

public NaiveBayes() {

}

@Override

/**

* @train主要完成求一些几率

* 1.labels中的不一样取值的几率f(Yi); 对应28,29行两段代码

* 2.将训练数组按目标值分类存储 第37行代码

* */

public void train(boolean[] isCategory, double[][] features, double[] labels){

isClassfication=isCategory;

countlblPro=new CountProbility(isCategory,features,labels);

countlblPro.getlblClass(lblClass, lblCount, lblProba);

ArrayList> trainingList=countlblPro.UnionFeaLbl(features, labels); //union the features[][] and labels[]

ClassListBasedLabel=countlblPro.getClassListBasedLabel(lblClass, trainingList);

}

@Override

/**3.在Y的条件下,计算Xi的几率 f(Xi/Y);

* 4.返回使得Yi*Xi*...几率最大的那个label的取值

* */

public double predict(double[] features) {

int max_index; //用于记录使几率取得最大的那个索引

int index=0; //这个索引是 标识不一样的labels 所对应的几率

ArrayList pro_=new ArrayList(); //这个几率数组是存储features[] 在不一样labels下对应的几率

for(ArrayList> elements: ClassListBasedLabel) //依次取不一样的label值对应的元祖集合

{

ArrayList pro=new ArrayList();//存同一个label对应的全部几率,以后其中的元素自乘

double probility=1.0; //计算几率的乘积

for(int i=0;i element:elements) //依次取labels中的全部元祖

{

if(element.get(i).equals(features[i])) //若是这个元祖的第index数据和b相等,那么就count就加1

count++;

}

if(count==0)

{

pro.add(1/(double)(elements.size()+1));

}

else

pro.add(count/(double)elements.size()); //统计完全部以后 计算几率值 并加入

}

else

{

double Sdev;

double Mean;

double probi=1.0;

Mean=countlblPro.getMean(elements, i);

Sdev=countlblPro.getSdev(elements, i);

if(Sdev!=0)

{

probi*=((1/(Math.sqrt(2*Math.PI)*Sdev))*(Math.exp(-(features[i]-Mean)*(features[i]-Mean)/(2*Sdev*Sdev))));

pro.add(probi);

}

else

pro.add(1.5);

}

}

for(double pi:pro)

probility*=pi; //将全部几率相乘

probility*=lblProba.get(index);//最后再乘以一个 Yi

pro_.add(probility);// 放入pro_ 至此 一个循环结束,

index++;

}

double max_pro=pro_.get(0);

max_index=0;

for(int i=1;i=max_pro)

{

max_pro=pro_.get(i);

max_index=i;

}

}

return lblClass.get(max_index);

}

public class CountProbility

{

boolean []isCatory;

double[][]features;

private double[]labels;

public CountProbility(boolean[] isCategory, double[][] features, double[] labels)

{

this.isCatory=isCategory;

this.features=features;

this.labels=labels;

}

//获取label中取值状况

public void getlblClass( ArrayList lblClass,ArrayListlblCount,ArrayListlblProba)

{

int j=0;

for(double i:labels)

{

//若是当前的label不存在于lblClass则加入

if(!lblClass.contains(i))

{

lblClass.add(j,i);

lblCount.add(j++,1);

}

else //若是label中已经存在,就将其计数加1

{

int index=lblClass.indexOf(i);

int count=lblCount.get(index);

lblCount.set(index,++count);

}

}

for(int i=0;i> UnionFeaLbl(double[][] features, double[] labels)

{

ArrayList>traingList=new ArrayList>();

for(int i=0;ielements=new ArrayList();

for(int j=0;j>> getClassListBasedLabel (ArrayList lblClass,ArrayList>trainingList)

{

ArrayList>> ClassListBasedLabel=new ArrayList>> () ;

for(double num:lblClass)

{

ArrayList> elements=new ArrayList>();

for(ArrayListelement:trainingList)

{

if(element.get(element.size()-1).equals(num))

elements.add(element);

}

ClassListBasedLabel.add(elements);

}

return ClassListBasedLabel;

}

public double getMean(ArrayList> elements,int index)

{

double sum=0.0;

double Mean;

for(ArrayList element:elements)

{

sum+=element.get(index);

}

Mean=sum/(double)elements.size();

return Mean;

}

public double getSdev(ArrayList> elements,int index)

{

double dev=0.0;

double Mean;

Mean=getMean(elements,index);

for(ArrayList element:elements)

{

dev+=Math.pow((element.get(index)-Mean),2);

}

dev=Math.sqrt(dev/elements.size());

return dev;

}

}

}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值