#include "opencv2/core.hpp" // 引入opencv2的核心头文件
#include "opencv2/video/tracking.hpp" // 引入opencv2视频跟踪相关功能的头文件
#include "opencv2/imgproc.hpp" // 引入opencv2的图像处理相关功能的头文件
#include "opencv2/highgui.hpp" // 引入opencv2的GUI界面和图像显示相关功能的头文件
#include "opencv2/ml.hpp" // 引入opencv2的机器学习模块的头文件
using namespace cv; // 使用opencv命名空间
using namespace cv::ml; // 使用opencv机器学习模块的命名空间
struct Data
{
Mat img; // 存放图像
Mat samples; // 训练样本集,其中包含图像上的点
Mat responses; // 训练样本的响应集
Data() // Data 的构造函数
{
const int WIDTH = 841; // 图像宽度
const int HEIGHT = 594; // 图像高度
img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3); // 创建全零值(黑色)图像,3个颜色通道
imshow("Train svmsgd", img); // 显示图像,窗口标题为 "Train svmsgd"
}
};
// 使用 SVMSGD 算法进行训练的函数
// samples 和 responses 是训练集
// weights 是 SVMSGD 算法的决策函数所需的向量
bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift);
// 找到画线(wx = 0)的两个点的函数
bool findPointsForLine(const Mat &weights, float shift, Point points[2], int width, int height);
// 找到线(wx = 0)和边界的交点的函数,边界为 (y = HEIGHT, 0 <= x <= WIDTH) 或 (x = WIDTH, 0 <= y <= HEIGHT)
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);
// 初始化边界线段的函数,分为 (y = HEIGHT, 0 <= x <= WIDTH) 和 (x = WIDTH, 0 <= y <= HEIGHT)
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height);
// 重绘图像中的点集合和线(wx = 0)
void redraw(Data data, const Point points[2]);
// 添加训练点,重新训练 SVMSGD 算法并在图像上绘制结果的函数
void addPointRetrainAndRedraw(Data &data, int x, int y, int response);
// 实际执行训练的函数实现
bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift)
{
cv::Ptr<SVMSGD> svmsgd = SVMSGD::create(); // 创建 SVMSGD 类的实例
cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses); // 创建训练数据
svmsgd->train(trainData); // 训练 SVMSGD 模型
if (svmsgd->isTrained()) // 如果模型训练成功
{
weights = svmsgd->getWeights(); // 获取模型权重
shift = svmsgd->getShift(); // 获取模型偏移
return true; // 返回训练成功
}
return false; // 如果训练失败,返回 false
}
// 初始化边界线段的函数实现
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
{
std::pair<Point,Point> currentSegment;
currentSegment.first = Point(width, 0); // 右边界线段
currentSegment.second = Point(width, height);
segments.push_back(currentSegment);
currentSegment.first = Point(0, height); // 底边界线段
currentSegment.second = Point(width, height);
segments.push_back(currentSegment);
currentSegment.first = Point(0, 0); // 顶边界线段
currentSegment.second = Point(width, 0);
segments.push_back(currentSegment);
currentSegment.first = Point(0, 0); // 左边界线段
currentSegment.second = Point(0, height);
segments.push_back(currentSegment);
}
// 函数findCrossPointWithBorders用于计算给定权重和偏移量下,直线与图像边界的交点
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
{
// 初始化交点的坐标
int x = 0;
int y = 0;
// 获取线段端点的横纵坐标极值
int xMin = std::min(segment.first.x, segment.second.x);
int xMax = std::max(segment.first.x, segment.second.x);
int yMin = std::min(segment.first.y, segment.second.y);
int yMax = std::max(segment.first.y, segment.second.y);
// 权重矩阵必须是单精度浮点类型
CV_Assert(weights.type() == CV_32FC1);
// 检查线段水平还是垂直
CV_Assert(xMin == xMax || yMin == yMax);
// 如果是垂直线段并且权重矩阵的第二个元素不为0
if (xMin == xMax && weights.at<float>(1) != 0)
{
x = xMin;
// 根据直线方程计算交点的y值
y = static_cast<int>(std::floor( - (weights.at<float>(0) * x + shift) / weights.at<float>(1)));
// 检测交点是否在图像的边界内
if (y >= yMin && y <= yMax)
{
crossPoint.x = x;
crossPoint.y = y;
return true;
}
}
// 如果是水平线段并且权重矩阵的第一个元素不为0
else if (yMin == yMax && weights.at<float>(0) != 0)
{
y = yMin;
// 根据直线方程计算交点的x值
x = static_cast<int>(std::floor( - (weights.at<float>(1) * y + shift) / weights.at<float>(0)));
// 检测交点是否在图像的边界内
if (x >= xMin && x <= xMax)
{
crossPoint.x = x;
crossPoint.y = y;
return true;
}
}
return false;
}
// 函数findPointsForLine用于计算用于绘制直线的两个点
bool findPointsForLine(const Mat &weights, float shift, Point points[2], int width, int height)
{
// 如果权重矩阵为空,则返回失败
if (weights.empty())
{
return false;
}
// 记录已找到的有效交点数量
int foundPointsCount = 0;
// 用于存储图像边框的4条线段(轮廓线)
std::vector<std::pair<Point,Point> > segments;
// 初始化线段(轮廓线)
fillSegments(segments, width, height);
// 遍历所有的边框线
for (uint i = 0; i < segments.size(); i++)
{
// 如果找到与边框的交点
if (findCrossPointWithBorders(weights, shift, segments[i], points[foundPointsCount]))
foundPointsCount++; //增加有效交点的数量
// 如果已找到了两个有效交点,则可以构成一条线,跳出循环
if (foundPointsCount >= 2)
break;
}
return true; //成功找到两个点
}
// 重新绘制数据集和分割线的函数实现
void redraw(Data data, const Point points[2])
{
data.img.setTo(0); // 将图像设置为全黑
Point center;
int radius = 3; // 点的半径
Scalar color; // 颜色
CV_Assert((data.samples.type() == CV_32FC1) && (data.responses.type() == CV_32FC1)); // 断言样本和响应数据类型正确
for (int i = 0; i < data.samples.rows; i++) // 遍历所有样本
{
center.x = static_cast<int>(data.samples.at<float>(i,0));
center.y = static_cast<int>(data.samples.at<float>(i,1));
color = (data.responses.at<float>(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128); // 根据响应值设定颜色
circle(data.img, center, radius, color, 5); // 绘制圆形点
}
line(data.img, points[0], points[1],cv::Scalar(1,255,1)); // 绘制分割线
imshow("Train svmsgd", data.img); // 显示图像
}
// 添加训练点,重新训练 SVMSGD 算法并在图像上绘制结果的函数实现
// 函数addPointRetrainAndRedraw用于添加新的训练点,并重新训练SVMSGD算法,然后重绘图形
void addPointRetrainAndRedraw(Data &data, int x, int y, int response)
{
// 创建一个1行2列的单精度浮点数矩阵,用于存储一个样本点
Mat currentSample(1, 2, CV_32FC1);
// 设定样本点的横纵坐标
currentSample.at<float>(0,0) = (float)x;
currentSample.at<float>(0,1) = (float)y;
// 将新样本点加入到样本集中
data.samples.push_back(currentSample);
// 将样本点的响应(类别)加入到响应集中
data.responses.push_back(static_cast<float>(response));
// 创建权重矩阵和偏移量
Mat weights(1, 2, CV_32FC1);
float shift = 0;
// 如果训练成功
if (doTrain(data.samples, data.responses, weights, shift))
{
// 创建Points数组用于存储线的两个点
Point points[2];
// 找到用于绘制直线的两个点
findPointsForLine(weights, shift, points, data.img.cols, data.img.rows);
// 使用找到的两个点重新绘制图形
redraw(data, points);
}
}
// 鼠标回调函数,用于在图像上添加正负样本点并重新训练模型
static void onMouse( int event, int x, int y, int, void* pData)
{
Data &data = *(Data*)pData; // 从pData转换获取Data结构体引用
switch( event ) // 根据事件类型
{
case EVENT_LBUTTONUP: // 左键松开事件
addPointRetrainAndRedraw(data, x, y, 1); // 添加正样本点并重新训练绘制
break;
case EVENT_RBUTTONDOWN: // 右键按下事件
addPointRetrainAndRedraw(data, x, y, -1);// 添加负样本点并重新训练绘制
break;
}
}
// 主函数
int main()
{
Data data; // 创建Data结构体实例
setMouseCallback( "Train svmsgd", onMouse, &data ); // 设置鼠标回调函数
waitKey(); // 等待按键
return 0; // 程序结束
}
该段代码是一个关于OpenCV和机器学习算法SVMSGD(支持向量机随机梯度下降)的简单示例,用于创建一个可交互的界面,在上面添加样本点,进行实时的线性分类器训练,并且通过绘制决策边界来显示分类结果。通过鼠标左键添加正样本,右键添加负样本。