此代码适用于 OpenCV 3。
数据集分为训练数据集和测试数据集。
训练模型代码:
Ptr<ml::TrainData> train_data, test_data;
train_data = ml::TrainData::loadFromCSV("../data/data_csv/data_train.csv",',');
int i = train_data->getLayout();
cout << i << " " << cv::ml::ROW_SAMPLE << endl;
test_data = ml::TrainData::loadFromCSV("../data/bayes_data/data_test.csv",',');
cv::Ptr< cv::ml::NormalBayesClassifier> bestClassifier;
double train_error = 100;
double test_error = 100;
Mat testSamples = test_data->getTrainSamples();
Mat TestResponses = test_data->getTrainResponses() ;
cv::Mat TestSamples;
normalize(testSamples, TestSamples, CV_BGR2HSV);
TestSamples.convertTo(TestSamples,CV_32FC1);
TestResponses.convertTo(TestResponses,CV_32S);
for (double ratio = 0.5; ratio < 1 ; ratio =ratio +0.05){
train_data->setTrainTestSplitRatio(ratio, true);
Mat TrainSamples = train_data->getTrainSamples();
Mat TrainResponses = train_data->getTrainResponses() ;
Mat TestSamples = train_data->getTestSamples();
Mat TestResponses = train_data->getTestResponses() ;
normalize(TrainSamples, TrainSamples, CV_BGR2HSV);
normalize(TestSamples, TestSamples, CV_BGR2HSV);
TestSamples.convertTo(TrainSamples,CV_32FC1);
TestResponses.convertTo(TestResponses,CV_32S);
Ptr<ml::NormalBayesClassifier> model=ml::NormalBayesClassifier::create();
cout <<"ratio " << ratio << endl;
bool ok = model->train(train_data);
if( !ok )
{
printf("Training failed\n");
}
else
{
float temp_train = model->calcError(train_data, false, noArray());
float temp_test = model->calcError(train_data, true, noArray());
if( (temp_train < train_error) && (temp_test< test_error) ){
cout <<"save this model" << endl;
printf( "train error: %f %f\n", temp_train ,train_error);
printf( "test error: %f %f\n\n", temp_test, test_error);
model->save("../bestClassifier_80_5label.xml");
train_error =temp_train;
test_error =temp_test;
}
}
}
加载模型进行预测:
Ptr<ml::NormalBayesClassifier> svm = ml::StatModel::load<ml::NormalBayesClassifier>("../bestClassifier_80_5label.xml");
cv::Mat testMat_2= (Mat_<double>( 32, 5) << 0,0,0,0,0,
0,0,0,0,1,
0,0,0,1,0,
0,0,1,0,0,
0,1,0,0,0);
testMat_2.convertTo(data,CV_32FC1);
svm->predict(data, res2);
cout << res2<< endl;