#include <opencv2/core.hpp>
#include <opencv2/imgproc.hpp>
#include "opencv2/imgcodecs.hpp"
#include <opencv2/highgui.hpp>
#include <opencv2/ml.hpp>
#include <iostream>
using namespace cv;
using namespace cv::ml;
/*训练部分*/
void train()
{
// Set up training data
int labels[4] = { 1, -1, -1, -1 };
Mat labelsMat(4, 1, CV_32SC1, labels);
float trainingData[4][2] = { { 501, 10 },{ 255, 10 },{ 501, 255 },{ 10, 501 } };
Mat trainingDataMat(4, 2, CV_32FC1, trainingData);
// Set up SVM's parameters
Ptr<cv::ml::SVM> svm = cv::ml::SVM::create();
svm->setType(cv::ml::SVM::C_SVC);
svm->setKernel(cv::ml::SVM::RBF);
svm->setTermCriteria(cvTermCriteria(CV_TERMCRIT_ITER, 100, 1e-6));
svm->setGamma(0.01);
svm->setC(800); //经验系数
svm->setP(0.1);
std::cout << "C为:" << svm->getC() << std::endl;
// Train the SVM
svm->train(trainingDataMat, cv::ml::ROW_SAMPLE, labelsMat);
svm->save("f:\\train_model.xml"); //保存模型
}
/*预测部分*/
void predict()
{
String path = "f:\\train_model.xml";
FileStorage svm_fs(path, FileStorage::READ); //读取文件
if (svm_fs.isOpened())
{
//testdata
int labels[4] = { 1, -1, -1, -1 };
Mat labelsMat(4, 1, CV_32SC1, labels);
float testingData[4][2] = { { 501, 10 },{ 255, 10 },{ 501, 255 },{ 10, 501 } };
Mat testingDataMat(4, 2, CV_32FC1, testingData);
//Ptr<ml::SVM> svm = ml::SVM::create();
//svm->load(path.c_str()); //从文件加载,这样是不对的—_—
Ptr<ml::SVM> svm = ml::SVM::load(path.c_str()); //从文件加载
std::cout << "C为:" << svm->getC() << std::endl; //读取一个参数检测是否加载成功
Mat result;
for (int i = 0; i < 4; i++)
{
Mat sample = testingDataMat.row(i);
float result = svm->predict(sample);
std::cout << "结果为:" << result << std::endl;
}
}
}
int main(int, char**)
{
train();
predict();
}
由于刚接触之前调试遇到了加载模型文件错误的问题,这段使用OpenCV3的代码是根据OpenCV2的例程改写的,遇到了模型文件加载错误的问题,具体表现为load以后svm并没有得到xml文件中的参数值,并且在predict时会报错。
最后发现直接load就可以不需要新建一个svm对象,load本身就可以创建svm了
/** @brief Loads and creates a serialized svm from a file
*
* Use SVM::save to serialize and store an SVM to disk.
* Load the SVM from this file again, by calling this function with the path to the file.
*
* @param filepath path to serialized svm
*/
CV_WRAP static Ptr<SVM> load(const String& filepath);
是真滴菜。。。