#include<iostream>
#include<opencv2/opencv.hpp>
using namespace std;
using namespace cv;
using namespace cv::ml;
int main()
{
//创建显示分割的图片
const int width = 512;
const int height = 512;
Mat image = Mat::zeros(height,width,CV_8UC3);
//标签数据
int labels[20];
for (int i = 0; i < 10; i++)
{
labels[i] = 1;
}
for (int i = 10; i < 20; i++)
{
labels[i] = 2;
}
Mat labelsMat(20,1,CV_32SC1,labels);
//制作样本点
float trainDataArray[20][2];
RNG rng(12345);
for (int i = 0; i < 10; i++)
{
trainDataArray[i][0] = 400 + static_cast<float>(rng.gaussian(30));
trainDataArray[i][1] = 400 + static_cast<float>(rng.gaussian(30));
}
for (int i = 10; i < 20; i++)
{
trainDataArray[i][0] = 30 + static_cast<float>(rng.gaussian(30));
trainDataArray[i][1] = 30 + static_cast<float>(rng.gaussian(30));
}
Mat trainMat(20,2,CV_32FC1,trainDataArray);
//创建机器学习算法
Ptr<SVM> svmModel = SVM::create();
//配置算法种类
svmModel->setType(SVM::C_SVC);//svm算法中的C-SVM分类器
svmModel->setKernel(SVM::LINEAR);//svm算法中的线性核函数
//设置算法的六个参数
svmModel->setC(1);//松弛变量C
//终止条件
svmModel->setTermCriteria(TermCriteria(TermCriteria::EPS,1000, FLT_EPSILON));
//开始训练
Ptr<TrainData> traind = TrainData::create(trainMat,ROW_SAMPLE,labelsMat);
svmModel->train(traind);
//存储训练模型
svmModel->save("SVM_DATA.xml");
//预测
Vec3b red(0, 0, 255), green(0, 255, 0), blue(255,0,0);
Mat sampleMat;
for (int i = 0; i < image.rows; i++)
{
for (int j = 0; j < image.cols; j++)
{
sampleMat = (Mat_<float>(1, 2) << j, i);
int response = svmModel->predict(sampleMat);
if (response == 1)
{
image.at<Vec3b>(i, j) = red;
}
if (response==2)
{
image.at<Vec3b>(i, j) = green;
}
}
}
//在分割图上显示样本点
for (int i = 0; i < trainMat.rows; i++)
{
const float* v = trainMat.ptr<float>(i);
Point pt = Point((int)v[0],(int)v[1]);
if (labels[i]==1)
{
circle(image, pt, 5, Scalar::all(255), -1, 8);
}
else if (labels[i] == 2)
{
circle(image, pt, 5, Scalar::all(128), -1, 8);
}
else
{
circle(image, pt, 5, Scalar::all(0), -1, 8);
}
}
//显示分类结果图像
imshow("SVM", image);
waitKey(0);
return 0;
}
训练好的模型已保存在 SVM_DATA.xml 文件中，
下面的程序直接加载该模型进行预测，无需重新训练：
#include<iostream>
#include<opencv2/opencv.hpp>
using namespace std;
using namespace cv;
using namespace cv::ml;
int main()
{
//创建显示分割的图片
const int width = 512;
const int height = 512;
Mat image = Mat::zeros(height,width,CV_8UC3);
//导入模型
Ptr<SVM> svmModel = Algorithm::load<SVM>("SVM_DATA.xml");
svmModel->load("SVM_DATA.xml");
//预测
Vec3b red(0, 0, 255), green(0, 255, 0), blue(255,0,0);
Mat sampleMat;
for (int i = 0; i < image.rows; i++)
{
for (int j = 0; j < image.cols; j++)
{
sampleMat = (Mat_<float>(1, 2) << j, i);
int response = svmModel->predict(sampleMat);
if (response == 1)
{
image.at<Vec3b>(i, j) = red;
}
if (response==2)
{
image.at<Vec3b>(i, j) = green;
}
}
}
//显示分类结果图像
imshow("SVM", image);
waitKey(0);
return 0;
}