用 OpenCvSharp 4.5 跑一遍 OpenCV 官方教程
原 OpenCV 官方教程链接：Support Vector Machines for Non-Linearly Separable Data
using System;
using OpenCvSharp;
using OpenCvSharp.ML;
namespace ConsoleApp1
{
    /// <summary>
    /// OpenCvSharp port of the OpenCV tutorial "Support Vector Machines for Non-Linearly Separable Data".
    /// Generates two partly overlapping point clouds, trains a linear C-SVC on them, paints the decision
    /// regions, the training samples and the support vectors, saves the picture to <c>result.png</c>
    /// and shows it in a window until a key is pressed.
    /// </summary>
    class tutorial41 : ITutorial
    {
        /// <summary>Number of training samples per class.</summary>
        private const int NTrainingSamples = 100;

        /// <summary>Fraction of each class's samples that lies in the linearly separable part.</summary>
        private const float FracLinearSep = 0.9f;

        /// <summary>Canvas width in pixels; sample x coordinates are generated in [0, Width).</summary>
        private const int Width = 512;

        /// <summary>Canvas height in pixels; sample y coordinates are generated in [0, Height).</summary>
        private const int Height = 512;

        /// <summary>
        /// Slack added to the unit functional margin when deciding whether a training sample is a
        /// support vector; absorbs solver tolerance and float32 rounding of the raw SVM output.
        /// </summary>
        private const double MarginTolerance = 1e-3;

        /// <summary>Runs the whole tutorial: data generation, training and visualization.</summary>
        public void Run()
        {
            using (Mat trainData = new Mat(2 * NTrainingSamples, 2, MatType.CV_32F))
            using (Mat labels = new Mat(2 * NTrainingSamples, 1, MatType.CV_32S))
            using (Mat image = new Mat(Height, Width, MatType.CV_8UC3, Scalar.All(0)))
            {
                //--------------------- 1. Set up training data randomly ---------------------------
                GenerateTrainingData(trainData, labels);

                //--------------------- 2. Set up the support vector machine parameters -----------
                Console.WriteLine("Starting training process");
                using (SVM svm = SVM.Create())
                {
                    svm.Type = SVM.Types.CSvc;
                    svm.C = 0.1;
                    svm.KernelType = SVM.KernelTypes.Linear;
                    svm.TermCriteria = new TermCriteria(CriteriaTypes.MaxIter, (int)1e7, 1e-6);

                    //--------------------- 3. Train the SVM --------------------------------------
                    svm.Train(trainData, SampleTypes.RowSample, labels);
                    Console.WriteLine("Finished training process");

                    //--------------------- 4. Show the decision regions --------------------------
                    DrawDecisionRegions(image, svm);

                    //--------------------- 5. Show the training data -----------------------------
                    DrawTrainingData(image, trainData);

                    //--------------------- 6. Show support vectors -------------------------------
                    DrawSupportVectors(image, svm, trainData, labels);
                }

                Cv2.ImWrite("result.png", image); // save the image
                Cv2.ImShow("SVM for Non-Linear Training Data", image); // show it to the user
                Cv2.WaitKey();
            }
        }

        /// <summary>
        /// Fills <paramref name="trainData"/> with random 2-D points and <paramref name="labels"/> with
        /// their classes. Rows [0, NTrainingSamples) are class 1 and the remaining rows are class 2.
        /// In each class, a fraction <see cref="FracLinearSep"/> of the samples lies in a linearly
        /// separable band: x in [0, 0.4*Width) for class 1 and [0.6*Width, Width) for class 2.
        /// The remaining samples of both classes share the overlapping band [0.4*Width, 0.6*Width).
        /// </summary>
        private static void GenerateTrainingData(Mat trainData, Mat labels)
        {
            RNG rng = new RNG(100); // fixed seed so the tutorial output is reproducible

            // Round instead of truncating: 0.9f * 100 is not exactly 90 in binary floating point.
            int nLinearSamples = (int)Math.Round(FracLinearSep * NTrainingSamples);
            int total = 2 * NTrainingSamples;

            // NOTE: the order of the Fill calls determines the random sequence; keep it stable.
            // Class 1, linearly separable part.
            FillColumnUniform(ref rng, trainData, 0, nLinearSamples, 0, 0, 0.4 * Width);
            FillColumnUniform(ref rng, trainData, 0, nLinearSamples, 1, 0, Height);

            // Class 2, linearly separable part.
            FillColumnUniform(ref rng, trainData, total - nLinearSamples, total, 0, 0.6 * Width, Width);
            FillColumnUniform(ref rng, trainData, total - nLinearSamples, total, 1, 0, Height);

            // Non-linearly separable part shared by both classes.
            FillColumnUniform(ref rng, trainData, nLinearSamples, total - nLinearSamples, 0, 0.4 * Width, 0.6 * Width);
            FillColumnUniform(ref rng, trainData, nLinearSamples, total - nLinearSamples, 1, 0, Height);

            // Labels for the classes.
            using (Mat class1 = labels.RowRange(0, NTrainingSamples))
            {
                class1.SetTo(1);
            }
            using (Mat class2 = labels.RowRange(NTrainingSamples, total))
            {
                class2.SetTo(2);
            }
        }

        /// <summary>
        /// Fills column <paramref name="col"/> of rows [<paramref name="rowStart"/>, <paramref name="rowEnd"/>)
        /// of <paramref name="data"/> with uniform random values in [<paramref name="low"/>, <paramref name="high"/>).
        /// The sub-matrix headers are views into <paramref name="data"/> and are disposed here.
        /// </summary>
        /// <param name="rng">Generator to draw from; passed by ref so its state advances for the caller.</param>
        private static void FillColumnUniform(ref RNG rng, Mat data, int rowStart, int rowEnd, int col, double low, double high)
        {
            using (Mat rows = data.RowRange(rowStart, rowEnd))
            using (Mat column = rows.ColRange(col, col + 1))
            {
                rng.Fill(column, DistributionType.Uniform, new Scalar(low), new Scalar(high));
            }
        }

        /// <summary>
        /// Colors every pixel of <paramref name="image"/> by the class the SVM predicts for the point
        /// (x = column, y = row): dark green for class 1, dark blue for class 2.
        /// </summary>
        private static void DrawDecisionRegions(Mat image, SVM svm)
        {
            Vec3b green = new Vec3b(0, 100, 0);
            Vec3b blue = new Vec3b(100, 0, 0);
            int rows = image.Rows;
            int cols = image.Cols;

            // Build one (x, y) sample per pixel in row-major order, so a single Predict call covers
            // the whole canvas instead of allocating a Mat and calling into native code per pixel.
            float[] coords = new float[rows * cols * 2];
            for (int i = 0, k = 0; i < rows; i++)
            {
                for (int j = 0; j < cols; j++, k += 2)
                {
                    coords[k] = j;
                    coords[k + 1] = i;
                }
            }

            using (Mat samples = new Mat(rows * cols, 2, MatType.CV_32F, coords))
            using (Mat responses = new Mat())
            {
                svm.Predict(samples, responses);
                for (int i = 0; i < rows; i++)
                {
                    for (int j = 0; j < cols; j++)
                    {
                        // Labels are small integers, so the float class label compares exactly.
                        float response = responses.At<float>(i * cols + j, 0);
                        if (response == 1)
                        {
                            image.At<Vec3b>(i, j) = green;
                        }
                        else if (response == 2)
                        {
                            image.At<Vec3b>(i, j) = blue;
                        }
                    }
                }
            }
        }

        /// <summary>
        /// Draws the training samples as filled dots: class 1 (first half of the rows) in green,
        /// class 2 (second half) in blue.
        /// </summary>
        private static void DrawTrainingData(Mat image, Mat trainData)
        {
            const int filled = -1;
            for (int i = 0; i < 2 * NTrainingSamples; i++)
            {
                float px = trainData.At<float>(i, 0);
                float py = trainData.At<float>(i, 1);
                Scalar color = i < NTrainingSamples ? new Scalar(0, 255, 0) : new Scalar(255, 0, 0);
                Cv2.Circle(image, new Point((int)px, (int)py), 3, color, filled);
            }
        }

        /// <summary>
        /// Circles, in grey, every training sample that is a support vector of the trained C-SVC.
        /// </summary>
        /// <remarks>
        /// For a linear kernel, OpenCV compresses the support vectors after training.
        /// <c>GetSupportVectors()</c> then returns a single row, the normal of the separating
        /// hyperplane, rather than training samples. Drawing that row as a point is meaningless
        /// (this was the "wrong data" seen previously). The C++ tutorial uses
        /// <c>getUncompressedSupportVectors()</c> instead; this port recovers the same set from the
        /// training data. A sample is a support vector exactly when its functional margin y·f(x)
        /// is at most 1, which covers samples on the margin, inside it, or misclassified.
        /// For a 2-class model, the predicted class is decided by the sign of the raw decision value,
        /// so |f(x)| is the margin of a correctly classified sample and −|f(x)| the margin of a
        /// misclassified one. This avoids depending on OpenCV's sign convention for f.
        /// </remarks>
        private static void DrawSupportVectors(Mat image, SVM svm, Mat trainData, Mat labels)
        {
            const int thick = 2;
            using (Mat predicted = new Mat())
            using (Mat raw = new Mat())
            {
                svm.Predict(trainData, predicted);
                svm.Predict(trainData, raw, StatModel.Flags.RawOutput);

                for (int i = 0; i < trainData.Rows; i++)
                {
                    bool correct = (int)predicted.At<float>(i, 0) == labels.At<int>(i, 0);
                    float magnitude = Math.Abs(raw.At<float>(i, 0));
                    double margin = correct ? magnitude : -magnitude;
                    if (margin > 1 + MarginTolerance)
                    {
                        continue; // strictly outside the margin: not a support vector
                    }

                    float px = trainData.At<float>(i, 0);
                    float py = trainData.At<float>(i, 1);
                    Cv2.Circle(image, new Point((int)px, (int)py), 6, new Scalar(128, 128, 128), thick);
                }
            }
        }
    }
}