Python图像处理(12):贝叶斯分类器

快乐虾

http://blog.csdn.net/lights_joy/

欢迎转载,但请保留作者信息


朴素贝叶斯分类算法是机器学习中十分经典而且应用十分广泛的算法,本文尝试用它进行数据点的分类。

OpenCV里面的分类器基本都是先训练,再预测,贝叶斯分类器也不例外。因此我们先生成训练数据,总共60个点:


# 训练的点数
train_pts = 30

# 创建测试的数据点,2类
# 以(-1.5, -1.5)为中心
rand1 = np.ones((train_pts,2)) * (-2) + np.random.rand(train_pts, 2)
print('rand1:')
print(rand1)

# 以(1.5, 1.5)为中心
rand2 = np.ones((train_pts,2)) + np.random.rand(train_pts, 2)
print('rand2:')
print(rand2)

# 合并随机点,得到训练数据
train_data = np.vstack((rand1, rand2))
train_data = np.array(train_data, dtype='float32')
train_label = np.vstack( (np.zeros((train_pts,1), dtype='int32'), np.ones((train_pts,1), dtype='int32')))

接下来就可以用train_datatrain_label进行训练了:

# 训练
bayer = cv2.ml.NormalBayesClassifier_create()
ret = bayer.train(train_data, cv2.ml.ROW_SAMPLE, train_label)

# 显示训练数据
plt.figure(1)
plt.plot(rand1[:,0], rand1[:,1], 'o')
plt.plot(rand2[:,0], rand2[:,1], 'o')

看看用于训练的点:


在训练完成后就可以用训练好的分类器进行预测:

# 测试数据,20个点[-2,2]
pt = np.array(np.random.rand(20,2) * 4 - 2, dtype='float32')
(ret, res) = bayer.predict(pt)
print("res = ")
print(res)

# 按label进行分类显示
plt.figure(2)
idx = np.hstack((res, res))
for i in range(0, 2) :
    type_data = pt[idx == i]
    type_data = np.reshape(type_data, (type_data.shape[0] / 2, 2))
    plt.plot(type_data[:,0], type_data[:,1], 'o')

plt.show()

看看测试数据和分类结果:



在使用此分类器的时候,发现opencvC++实现代码中有一个BUG,如果进行predict的测试数据是一个数组而不是一个点,opencv会执行时会停在下述代码的注释部分:


    float predictProb( InputArray _samples, OutputArray _results, OutputArray _resultsProb, int flags ) const
    {
        int value=0;
        Mat samples = _samples.getMat(), results, resultsProb;
        int nsamples = samples.rows, nclasses = (int)cls_labels.total();
        bool rawOutput = (flags & RAW_OUTPUT) != 0;

        if( samples.type() != CV_32F || samples.cols != nallvars )
            CV_Error( CV_StsBadArg,
                     "The input samples must be 32f matrix with the number of columns = nallvars" );

/*有问题的代码??*/
        //if( samples.rows > 1 && _results.needed() )
        //    CV_Error( CV_StsNullPtr,
        //             "When the number of input samples is >1, the output vector of results must be passed" );

        if( _results.needed() )
        {
            _results.create(nsamples, 1, CV_32S);
            results = _results.getMat();
        }
        else
            results = Mat(1, 1, CV_32S, &value);

        if( _resultsProb.needed() )
        {
            _resultsProb.create(nsamples, nclasses, CV_32F);
            resultsProb = _resultsProb.getMat();
        }

        cv::parallel_for_(cv::Range(0, nsamples),
                          NBPredictBody(c, cov_rotate_mats, inv_eigen_values, avg, samples,
                                       var_idx, cls_labels, results, resultsProb, rawOutput));

        return (float)value;
    }


实际上这个判断条件完全是多余的,直接去除即可。












  • 0
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

嵌云阁主

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值