与上一篇的区别在于 SVM 参数 C 的设置（setC()）。
#include <iostream>
#include<opencv2/opencv.hpp>
constexpr int NTRAINING_SAMPLES = 100;  // Number of training samples per class
constexpr float FRAC_LINEAR_SEP = 0.9f; // Fraction of each class's samples placed in its linearly separable zone
using namespace cv;
using namespace std;
using namespace cv::ml;
int main()
{
// Data for visual representation
const int WIDTH = 512, HEIGHT = 512;
Mat I = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);
//--------------------- 1. Set up training data randomly ---------------------------------------
Mat trainData(2 * NTRAINING_SAMPLES, 2, CV_32FC1);
Mat labels(2 * NTRAINING_SAMPLES, 1, CV_32SC1);
RNG rng(100); // Random value generation class
// Set up the linearly separable part of the training data
int nLinearSamples = (int)(FRAC_LINEAR_SEP * NTRAINING_SAMPLES);
// Generate random points for the class 1
Mat trainClass = trainData.rowRange(0, nLinearSamples);
// The x coordinate of the points is in [0, 0.4)
Mat c = trainClass.colRange(0, 1);
rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(0.4 * WIDTH));
// The y coordinate of the points is in [0, 1)
c = trainClass.colRange(1, 2);
rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(HEIGHT));
// Generate random points for the class 2
trainClass = trainData.rowRange(2 * NTRAINING_SAMPLES - nLinearSamples, 2 * NTRAINING_SAMPLES);
// The x coordinate of the points is in [0.6, 1]
c = trainClass.colRange(0, 1);
rng.fill(c, RNG::UNIFORM, Scalar(0.6 * WIDTH), Scalar(WIDTH));
// The y coordinate of the points is in [0, 1)
c = trainClass.colRange(1, 2);
rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(HEIGHT));
//------------------ Set up the non-linearly separable part of the training data ---------------
// Generate random points for the classes 1 and 2
trainClass = trainData.rowRange(nLinearSamples, 2 * NTRAINING_SAMPLES - nLinearSamples);
// The x coordinate of the points is in [0.4, 0.6)
c = trainClass.colRange(0, 1);
rng.fill(c, RNG::UNIFORM, Scalar(0.4 * WIDTH), Scalar(0.6 * WIDTH));
// The y coordinate of the points is in [0, 1)
c = trainClass.colRange(1, 2);
rng.fill(c, RNG::UNIFORM, Scalar(1), Scalar(HEIGHT));
//------------------------- Set up the labels for the classes ---------------------------------
labels.rowRange(0, NTRAINING_SAMPLES).setTo(1); // Class 1
labels.rowRange(NTRAINING_SAMPLES, 2 * NTRAINING_SAMPLES).setTo(2); // Class 2
//------------------------ 2. Set up the support vector machines parameters --------------------
Ptr<SVM> svmModel = SVM::create();
svmModel->setType(SVM::C_SVC);
svmModel->setKernel(SVM::LINEAR);
svmModel->setC(0.1);
svmModel->setTermCriteria(TermCriteria(TermCriteria::EPS,10000,1e-6));
//------------------------ 3. Train the svm ----------------------------------------------------
cout << "Starting training process" << endl;
Ptr<TrainData> traind = TrainData::create(trainData,ROW_SAMPLE,labels);
svmModel->train(traind);
cout << "Finished training process" << endl;
svmModel->save("SVM_DATA2.xml");
//------------------------ 4. Show the decision regions ----------------------------------------
Vec3b green(0, 100, 0), blue(100, 0, 0);
for (int i = 0; i < I.rows; ++i)
for (int j = 0; j < I.cols; ++j)
{
Mat sampleMat = (Mat_<float>(1, 2) << i, j);
float response = svmModel->predict(sampleMat);
if (response == 1) I.at<Vec3b>(j, i) = green;
else if (response == 2) I.at<Vec3b>(j, i) = blue;
}
//----------------------- 5. Show the training data --------------------------------------------
int thick = -1;
int lineType = 8;
float px, py;
// Class 1
for (int i = 0; i < NTRAINING_SAMPLES; ++i)
{
px = trainData.at<float>(i, 0);
py = trainData.at<float>(i, 1);
circle(I, Point((int)px, (int)py), 3, Scalar(0, 255, 0), thick, lineType);
}
// Class 2
for (int i = NTRAINING_SAMPLES; i < 2 * NTRAINING_SAMPLES; ++i)
{
px = trainData.at<float>(i, 0);
py = trainData.at<float>(i, 1);
circle(I, Point((int)px, (int)py), 3, Scalar(255, 0, 0), thick, lineType);
}
//------------------------- 6. Show support vectors --------------------------------------------
/*thick = 6;
lineType = 8;
Mat x = svmModel->getSupportVectors();
for (int i = 0; i < x.rows; ++i)
{
const float* v = x.ptr<float>(i);
cout << v[0] << v[1] << endl;
circle(I, Point((int)v[0], (int)v[1]), 6, Scalar(0, 0, 255), thick, lineType);
}*/
//imwrite("result.png", I); // save the Image
imshow("线性不可分二类问题", I); // show it to the user
waitKey(0);
return 0;
}