[Machine Learning Algorithms] Naive Bayes

1. Introduction to the Naive Bayes Algorithm

For an accessible introduction to this algorithm, see:
http://www.cnblogs.com/leoo2sk/archive/2010/09/17/naive-bayesian-classifier.html

My understanding of the naive Bayes method:
statistical learning = model + strategy + algorithm
1. Model:
Naive Bayes is a generative learning algorithm: it aims to learn p(x,y), the joint distribution of the input vector and the output label. By the multiplication rule, p(x,y) = p(x|y)*p(y), so learning p(x,y) amounts to learning p(x|y) and p(y).
The naive Bayes method makes the following assumptions:
both y and x are discrete, and p(x|y) = p(x1|y)*p(x2|y)*...*p(xn|y), i.e. the features are conditionally independent given the label.
2. Strategy:
As I understand it, the strategy can in some sense be viewed as the criterion for judging how good a model is: whatever model we hope to learn determines what strategy we should design. In the naive Bayes algorithm we adopt the empirical-risk-minimization strategy, i.e. we use maximum likelihood estimation to solve for p(x|y) and p(y).
《统计学习方法》 contains a passage proving that the classification rule naive Bayes uses, maximum a posteriori probability, is equivalent to minimizing the expected risk. I find this passage easy to misread: the goal of the naive Bayes algorithm itself is only to solve for p(x,y). Only after p(x,y) has been obtained does the question of how to classify y arise, so that proof would sit better at the end of the chapter.
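That equivalence can be sketched in one line. With 0-1 loss over classes $c_k$, the expected risk of a classifier $f$ is

```latex
R(f) = \mathbb{E}_X \sum_{k} \mathbf{1}\{ f(X) \neq c_k \}\, P(c_k \mid X)
     = \mathbb{E}_X \bigl[ 1 - P(f(X) \mid X) \bigr],
```

so minimizing R(f) pointwise at each x yields f(x) = \arg\max_{c_k} P(c_k \mid x), which is exactly the maximum a posteriori rule.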
3. Algorithm:
Once the maximum-likelihood objective has been written down, the next question is how to solve it. The algorithm is the method for finding the maximum or minimum of that expression; here simple differentiation suffices. See 《ufldl note2》 for details.
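To make the three pieces concrete, here is a minimal Python sketch of the closed-form maximum-likelihood estimates with Laplace smoothing (the toy data and the name fit_naive_bayes are my own illustration, not from the sources above):

```python
from collections import Counter, defaultdict

def fit_naive_bayes(docs, labels, alpha=1.0):
    """Estimate p(y) and p(x_i|y) by (Laplace-smoothed) maximum likelihood.

    docs   : list of token lists
    labels : list of class labels, same length as docs
    alpha  : smoothing constant (alpha=1 is Laplace smoothing)
    """
    vocab = {w for d in docs for w in d}
    class_counts = Counter(labels)            # counts for estimating p(y)
    word_counts = defaultdict(Counter)        # per-class counts for p(x_i|y)
    for d, y in zip(docs, labels):
        word_counts[y].update(d)

    n = len(docs)
    priors = {y: c / n for y, c in class_counts.items()}
    likelihoods = {}
    for y in class_counts:
        total = sum(word_counts[y].values())  # total words seen in class y
        likelihoods[y] = {w: (word_counts[y][w] + alpha) /
                             (total + alpha * len(vocab))
                          for w in vocab}
    return priors, likelihoods

# toy corpus: two "spam" documents and one "ham" document
docs = [["buy", "cheap", "pills"], ["cheap", "pills"], ["meeting", "tomorrow"]]
labels = [1, 1, 0]
priors, likelihoods = fit_naive_bayes(docs, labels)
```

Note that the smoothed likelihoods still sum to 1 over the vocabulary for each class, so smoothing keeps the estimates valid probability distributions while avoiding zero counts.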

2. A C++ Implementation of Naive Bayes

This implementation follows http://blog.csdn.net/lavorange/article/details/17841383?utm_source=tuicool&utm_medium=referral
That blogger kept the program structure of 《机器学习实战》 (Machine Learning in Action) and translated its Python code into C++. However, the code given in the book contains a few small errors, which I have fixed here.

/*
 * code list 4-1 : transfer func from docs list to vocabulary list
 * code list 4-2 : training func on Naive Bayes Classifier
 * code list 4-3 : naive bayes classify function
 * add code list 4-4 : naive bayes bag-of-word model
 * add code list 4-5 : text parse : textParse.py and spam email test function : get_error_rate()
 * */

#include<iostream>
#include<map>
#include<set>
#include<cmath>
#include<vector>
#include<algorithm>
#include<numeric>
#include<string>
#include<stdio.h>
#include<cstdlib>
#include<fstream>
#include<stdlib.h>
#include<unistd.h>

using namespace std;

class NaiveBayes
{
    private:
        vector< vector<string> > list_of_docs;
        vector<int> list_classes;
        map<string,int>  my_vocab_list;
        int *return_vec;
        vector< vector<int> > train_mat;
        vector<float> p0vect;
        vector<float> p1vect;
        float p_abusive;
        ifstream fin;
        ofstream fout;
        int test_data_num;

    public:
        NaiveBayes()
        {
            cout<<"please input the num of test data which should be less than 24 : "<<endl;
            cin>>test_data_num;
            vector<string> vec;
            string word;
            string filename;
            char buf[3];
            string buf_str;
            for(int i=test_data_num+1;i<=25;i++)
            {
                sprintf(buf,"%d",i);  //convert digit to string
                vec.clear();
                buf_str = buf;
                filename = "./email/hamParse/"+buf_str+".dat";
                //cout<<"filename : "<<filename<<endl;
                fin.open( filename.c_str() );
                if(!fin)
                {
                    cerr<<"open the file "<<filename<<" error"<<endl;
                    exit(1);
                }
                while(fin>>word)
                {
                    vec.push_back(word);
                }
                list_of_docs.push_back( vec );
                list_classes.push_back(0);
                filename.clear();
                fin.close();
            }

            for(int i=test_data_num+1;i<=25;i++)
            {
                sprintf(buf,"%d",i);
                vec.clear();
                buf_str = buf;
                filename =  "./email/spamParse/"+buf_str+".dat";
                //cout<<"filename : "<<filename<<endl;
                fin.open( filename.c_str() );
                if(!fin)
                {
                    cerr<<"open the file "<<filename<<" error"<<endl;
                    exit(1);
                }
                while(fin>>word)
                {
                    vec.push_back(word);
                }
                list_of_docs.push_back( vec );
                list_classes.push_back(1);
                filename.clear();
                fin.close();
            }

        }

        ~NaiveBayes()
        {
            fin.close();
            fout.close();
            list_of_docs.clear();
            list_classes.clear();
            my_vocab_list.clear();
            train_mat.clear();
            //delete [] return_vec;
            p0vect.clear();
            p1vect.clear();
        }


        void create_vocab_list()
        {
            vector< vector<string> > :: iterator it = list_of_docs.begin();
            int index = 1;
            while( it!=list_of_docs.end() )
            {
                //vector<string> vec( *it.begin(),*it.end() );
                vector<string> vec = *it;

                vector<string> :: iterator tmp_it = vec.begin();

                while( tmp_it!=vec.end() )
                {
                    //cout<<*tmp_it<<" ";
                    if( my_vocab_list[*tmp_it] == 0 )
                    {
                        my_vocab_list[*tmp_it] = index++; //index is the position of the word in the vocabulary
                    }
                    tmp_it++;
                }
                it++;
            }

        }//create_vocab_list

        //count the words of document idx into a vector over the vocabulary (bag-of-words model)
        void beg_of_words_to_vec(int idx)
        {
            //cout<<"set of words to vec begin the document id is : "<<idx<<endl;
            int len = my_vocab_list.size()+1;
            return_vec = new int[ len ](); //the trailing "()" value-initializes every element to zero, unlike plain "new int[len]"
            vector< vector<string> >:: iterator it = list_of_docs.begin() + idx - 1  ;
            vector<string> vec  = *it;
            vector<string> :: iterator itt = vec.begin();
            int pos = 0 ;
            while( itt!=vec.end() )
            {
    //          cout<<*itt<<" ";
                pos = my_vocab_list[ *itt ];
                if(pos!=0)
                {
                    return_vec[pos] += 1;
                }
                itt++;
            }
        }//beg_of_words_to_vec

        void get_train_matrix()
        {
            cout<<"get train matrix begin : "<<endl;
            train_mat.clear();
            for(int i=1;i<=list_of_docs.size();i++)
            {
                beg_of_words_to_vec(i);
                vector<int> vec( return_vec , return_vec + my_vocab_list.size()+1 );
                train_mat.push_back(vec);
                delete []return_vec;
            }
        }//get train matrix

        void print()
        {
            cout<<"print the train matrix begin : "<<endl;
            vector< vector<int> > :: iterator it = train_mat.begin();
            while(it!=train_mat.end())
            {
                vector<int> vec = *it;
                vector<int> :: iterator itt = vec.begin();
                while( itt!=vec.end())
                {
                    cout<<*itt<<" ";
                    itt++;
                }
                cout<<endl;
                it++;
            }

        }//print()

        void train_NB0()
        {
            int num_train_docs = train_mat.size();//sizeof(docs_lists)/sizeof(docs_lists[0]);
            cout<<"num_train_docs = "<<num_train_docs<<endl;
            int num_words = train_mat[0].size() - 1 ;
            /* calculate the sum of the abusive classes */
            int sum = accumulate(list_classes.begin(),list_classes.end(),0);
            cout<<"sum = "<<sum<<endl;
            //float p_abusive = (float)sum/(float)num_train_docs;
            p_abusive =  (float)sum/(float)num_train_docs;
            cout<<"p_abusive = "<<p_abusive<<endl;

            //vector<float> p0vect(train_mat[0].size(),1); //the frequency of each word in non-absusive docs
            p0vect.resize(train_mat[0].size(),1);
            //vector<float> p1vect(train_mat[0].size(),1); //the frequency of each word in abusive docs
            p1vect.resize(train_mat[0].size(),1);
            printf("p0num.size() = %zu , p1num.size() = %zu\n",p0vect.size(),p1vect.size()); //%zu matches size_t
            float p0Denom = 2.0; //the total number of words in non-abusive docs
            float p1Denom = 2.0; //the total number of words in abusive docs

            /* calculate the p0num,p1num,p0Denom,p1Denom */
            for(int i=0;i<list_classes.size();i++)
            {
                if(list_classes[i] == 1)  //abusive doc
                {
                    p1Denom++;
                    for(int j=0;j<p1vect.size();j++)
                    {
                        p1vect[j] += train_mat[i][j];


                    }
                }
                else   //non-abusive doc
                {
                    p0Denom++;
                    for(int j=0;j<p0vect.size();j++)
                    {
                        p0vect[j] += train_mat[i][j];


                    }
                }
            }

            for(int i=0;i<p1vect.size();i++)
            {
                p0vect[i] = log(p0vect[i]/p0Denom);
                p1vect[i] = log(p1vect[i]/p1Denom);
            }

            cout<<endl;
        }

        int classify_NB(const char  *filename )
        {
            return_vec = new int[ my_vocab_list.size()+1 ]();

            fin.open(filename);
            if(!fin)
            {
                cerr<<"fail to open the file "<<filename<<endl;
                exit(1);
            }
            string word;
            while(fin>>word)
            {
                int pos = my_vocab_list[ word ];
                if( pos!=0 )
                {
                    return_vec[ pos ] += 1;
                }
            }
            fin.close();

            cout<<endl;
            float p1 = inner_product( p1vect.begin()+1,p1vect.end(),return_vec+1,0.0f ) + log(p_abusive);   //initial value must be 0.0f: an int 0 would make inner_product accumulate in int and truncate the log-probabilities
            float p0 = inner_product( p0vect.begin()+1,p0vect.end(),return_vec+1,0.0f ) + log(1-p_abusive);

            cout<<"p1 = "<<p1<<"  "<<"p0 = "<<p0<<endl;

            int label = ( p1>p0 ) ? 1 : 0;
            delete [] return_vec;   //release the count vector allocated at the top of this function
            return label;
        }

        void get_error_rate()
        {
            string filename ;
            char buf[3];
            string buf_str;
            int error_count = 0;
            for(int i=1;i<=test_data_num;i++)   
            {
                sprintf(buf,"%d",i);
                buf_str = buf;
                filename = "./email/hamParse/"+buf_str+".dat";
                if( classify_NB( filename.c_str() ) != 0 )
                {
                    error_count++;
                }

                filename = "./email/spamParse/"+buf_str+".dat";
                if( classify_NB( filename.c_str() ) != 1 )
                {
                    error_count++;
                }
            }       
            cout<<"the error rate is : "<<(float)error_count/(float)(2*test_data_num)<<endl;

        }
};

int main()
{
    NaiveBayes nb;
    nb.create_vocab_list();
    //nb.beg_of_words_to_vec(5);
    //nb.beg_of_words_to_vec(30);
    nb.get_train_matrix();
    //nb.print();
    nb.train_NB0();

    char  doc1_to_classify[] = "./email/hamParse/1.dat";
    char  doc2_to_classify[] = "./email/spamParse/1.dat";
    cout<<"doc1 classified as : "<<nb.classify_NB( doc1_to_classify )<<endl;
    cout<<"doc2 classified as : "<<nb.classify_NB( doc2_to_classify )<<endl;

    nb.get_error_rate();
    return 0;
}

3. Python Implementation of Naive Bayes

from numpy import ones, log

def trainNB0(trainMatrix,trainCategory):
    numTrainDocs = len(trainMatrix)
    numWords = len(trainMatrix[0])
    pAbusive = sum(trainCategory)/float(numTrainDocs)
    p0Num = ones(numWords); p1Num = ones(numWords)      #change to ones() 
    p0Denom = 2.0; p1Denom = 2.0                        #change to 2.0
    for i in range(numTrainDocs):
        if trainCategory[i] == 1:
            p1Num += trainMatrix[i]
            p1Denom += 1
        else:
            p0Num += trainMatrix[i]
            p0Denom += 1
    p1Vect = log(p1Num/p1Denom)          #change to log()
    p0Vect = log(p0Num/p0Denom)          #change to log()
    return p0Vect,p1Vect,pAbusive
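
trainNB0 only produces the parameters; the book pairs it with a classifier that compares the two class scores in log space, dropping the shared p(x) denominator. A self-contained sketch (the hand-made probability vectors below are illustrative stand-ins for trainNB0's output, not values learned from the email corpus):

```python
from numpy import array, log

def classifyNB(vec2Classify, p0Vec, p1Vec, pClass1):
    """Compare log p(x|y)+log p(y) for both classes; p0Vec/p1Vec are log-likelihood vectors."""
    p1 = sum(vec2Classify * p1Vec) + log(pClass1)
    p0 = sum(vec2Classify * p0Vec) + log(1.0 - pClass1)
    return 1 if p1 > p0 else 0

# tiny hand-made example: 3-word vocabulary, equal class priors
p0Vec = log(array([0.6, 0.3, 0.1]))   # word likelihoods under class 0
p1Vec = log(array([0.1, 0.3, 0.6]))   # word likelihoods under class 1
label = classifyNB(array([0, 0, 1]), p0Vec, p1Vec, 0.5)   # third word favors class 1
```

Because vec2Classify holds word counts and p0Vec/p1Vec hold log-likelihoods, the elementwise product summed over the vocabulary is exactly the log of the naive factorization p(x1|y)*...*p(xn|y).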