opencv-bayes模型训练以及加载

1 篇文章 0 订阅
1 篇文章 0 订阅

此代码适用于opencv3

数据集分开训练数据集和测试数据集合;

训练模型代码

   Ptr<ml::TrainData> train_data, test_data;

   train_data = ml::TrainData::loadFromCSV("../data/data_csv/data_train.csv",',');

    int i = train_data->getLayout();
    cout << i << " " << cv::ml::ROW_SAMPLE << endl;
    test_data = ml::TrainData::loadFromCSV("../data/bayes_data/data_test.csv",',');
    cv::Ptr< cv::ml::NormalBayesClassifier> bestClassifier;
    double train_error = 100;
    double test_error = 100;

    Mat testSamples = test_data->getTrainSamples();
    Mat TestResponses = test_data->getTrainResponses()  ;

    cv::Mat TestSamples;
    normalize(testSamples, TestSamples, CV_BGR2HSV);
    TestSamples.convertTo(TestSamples,CV_32FC1);
    TestResponses.convertTo(TestResponses,CV_32S);

    for (double ratio = 0.5; ratio < 1 ; ratio =ratio +0.05){

        train_data->setTrainTestSplitRatio(ratio, true);

        Mat TrainSamples = train_data->getTrainSamples();
        Mat TrainResponses = train_data->getTrainResponses()  ;
        Mat TestSamples = train_data->getTestSamples();
        Mat TestResponses = train_data->getTestResponses()  ;
        normalize(TrainSamples, TrainSamples, CV_BGR2HSV);
        normalize(TestSamples, TestSamples, CV_BGR2HSV);

        TestSamples.convertTo(TrainSamples,CV_32FC1);
        TestResponses.convertTo(TestResponses,CV_32S);


        Ptr<ml::NormalBayesClassifier> model=ml::NormalBayesClassifier::create();
        cout <<"ratio " << ratio << endl;


        bool ok = model->train(train_data);
        if( !ok )
        {
            printf("Training failed\n");
        }
        else
        {
            float temp_train = model->calcError(train_data, false, noArray());
            float temp_test = model->calcError(train_data, true, noArray());

            if( (temp_train < train_error) && (temp_test< test_error)    ){
                cout <<"save this model" << endl;
                printf( "train error: %f  %f\n", temp_train ,train_error);
                printf( "test error: %f  %f\n\n", temp_test, test_error);
                model->save("../bestClassifier_80_5label.xml");
                train_error =temp_train;
                test_error =temp_test;

            }
        }

    }

加载模型进行预测:


Ptr<ml::NormalBayesClassifier> svm = ml::StatModel::load<ml::NormalBayesClassifier>("../bestClassifier_80_5label.xml");
cv::Mat testMat_2= (Mat_<double>( 32, 5) << 0,0,0,0,0,
        0,0,0,0,1,
        0,0,0,1,0,
        0,0,1,0,0,
        0,1,0,0,0);
testMat_2.convertTo(data,CV_32FC1);
svm->predict(data, res2);
cout << res2<< endl;
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值