OpenCV机器学习(1):贝叶斯分类器实现代码分析

OpenCV的机器学习类定义在ml.hpp文件中,基础类是CvStatModel,其他各种分类器从这里继承而来。

今天研究CvNormalBayesClassifier分类器。

1.类定义

在ml.hpp中有以下类定义:

class CV_EXPORTS_W CvNormalBayesClassifier : public CvStatModel
{
public:
    CV_WRAP CvNormalBayesClassifier();
    virtual ~CvNormalBayesClassifier();

    CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,
        const CvMat* varIdx=0, const CvMat* sampleIdx=0 );

    virtual bool train( const CvMat* trainData, const CvMat* responses,
        const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );

    virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0 ) const;
    CV_WRAP virtual void clear();

    CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
                            const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );
    CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
                       const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
                       bool update=false );
    CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0 ) const;

    virtual void write( CvFileStorage* storage, const char* name ) const;
    virtual void read( CvFileStorage* storage, CvFileNode* node );

protected:
    int     var_count, var_all;
    CvMat*  var_idx;
    CvMat*  cls_labels;
    CvMat** count;
    CvMat** sum;
    CvMat** productsum;
    CvMat** avg;
    CvMat** inv_eigen_values;
    CvMat** cov_rotate_mats;
    CvMat*  c;
};

2.示例

此类使用方法如下:(引用别人的代码,忘记出处了,非常抱歉这个。。。)

//openCV中贝叶斯分类器的API函数用法举例
//运行环境:win7 + VS2005 + openCV2.4.5

#include "global_include.h"

using namespace std;
using namespace cv;

//10个样本特征向量维数为12的训练样本集,第一列为该样本的类别标签
double inputArr[10][13] = 
{
     1,0.708333,1,1,-0.320755,-0.105023,-1,1,-0.419847,-1,-0.225806,0,1, 
    -1,0.583333,-1,0.333333,-0.603774,1,-1,1,0.358779,-1,-0.483871,0,-1,
     1,0.166667,1,-0.333333,-0.433962,-0.383562,-1,-1,0.0687023,-1,-0.903226,-1,-1,
    -1,0.458333,1,1,-0.358491,-0.374429,-1,-1,-0.480916,1,-0.935484,0,-0.333333,
    -1,0.875,-1,-0.333333,-0.509434,-0.347032,-1,1,-0.236641,1,-0.935484,-1,-0.333333,
    -1,0.5,1,1,-0.509434,-0.767123,-1,-1,0.0534351,-1,-0.870968,-1,-1,
     1,0.125,1,0.333333,-0.320755,-0.406393,1,1,0.0839695,1,-0.806452,0,-0.333333,
     1,0.25,1,1,-0.698113,-0.484018,-1,1,0.0839695,1,-0.612903,0,-0.333333,
     1,0.291667,1,1,-0.132075,-0.237443,-1,1,0.51145,-1,-0.612903,0,0.333333,
     1,0.416667,-1,1,0.0566038,0.283105,-1,1,0.267176,-1,0.290323,0,1
};

//一个测试样本的特征向量
double testArr[]=
{
    0.25,1,1,-0.226415,-0.506849,-1,-1,0.374046,-1,-0.83871,0,-1
};


int _tmain(int argc, _TCHAR* argv[])
{
    Mat trainData(10, 12, CV_32FC1);//构建训练样本的特征向量
    for (int i=0; i<10; i++)
    {
        for (int j=0; j<12; j++)
        {
            trainData.at<float>(i, j) = inputArr[i][j+1];
        }
    }

    Mat trainResponse(10, 1, CV_32FC1);//构建训练样本的类别标签
    for (int i=0; i<10; i++)
    {
        trainResponse.at<float>(i, 0) = inputArr[i][0];
    }

    CvNormalBayesClassifier nbc;
    bool trainFlag = nbc.train(trainData, trainResponse);//进行贝叶斯分类器训练
    if (trainFlag)
    {
        cout<<"train over..."<<endl;
        nbc.save("normalBayes.txt");
    }
    else
    {
        cout<<"train error..."<<endl;
   
  • 2
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 8
    评论
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值