【opencv】示例-train_svmsgd.cpp 随机梯度下降支持向量机(SVMSGD)对二维数据进行分类的UI...

6389938ae1f28455547f39da35bee25c.png

493b36e63a1aa84805bfa9215f553ff3.png

#include "opencv2/core.hpp"                     // 引入opencv2的核心头文件
#include "opencv2/video/tracking.hpp"           // 引入opencv2视频跟踪相关功能的头文件
#include "opencv2/imgproc.hpp"                  // 引入opencv2的图像处理相关功能的头文件
#include "opencv2/highgui.hpp"                  // 引入opencv2的GUI界面和图像显示相关功能的头文件
#include "opencv2/ml.hpp"                       // 引入opencv2的机器学习模块的头文件


using namespace cv;                            // 使用opencv命名空间
using namespace cv::ml;                        // 使用opencv机器学习模块的命名空间


struct Data                                     
{
    Mat img;                                   // 存放图像
    Mat samples;                              // 训练样本集,其中包含图像上的点
    Mat responses;                            // 训练样本的响应集


    Data()                                    // Data 的构造函数
    {
        const int WIDTH = 841;                // 图像宽度
        const int HEIGHT = 594;               // 图像高度
        img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3); // 创建全零值(黑色)图像,3个颜色通道
        imshow("Train svmsgd", img);          // 显示图像,窗口标题为 "Train svmsgd"
    }
};
// 使用 SVMSGD 算法进行训练的函数
// samples 和 responses 是训练集
// weights 是 SVMSGD 算法的决策函数所需的向量
bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift);


// 找到画线(wx = 0)的两个点的函数
bool findPointsForLine(const Mat &weights, float shift, Point points[2], int width, int height);


// 找到线(wx = 0)和边界的交点的函数,边界为 (y = HEIGHT, 0 <= x <= WIDTH) 或 (x = WIDTH, 0 <= y <= HEIGHT)
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);


// 初始化边界线段的函数,分为 (y = HEIGHT, 0 <= x <= WIDTH) 和 (x = WIDTH, 0 <= y <= HEIGHT)
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height);


// 重绘图像中的点集合和线(wx = 0)
void redraw(Data data, const Point points[2]);


// 添加训练点,重新训练 SVMSGD 算法并在图像上绘制结果的函数
void addPointRetrainAndRedraw(Data &data, int x, int y, int response);


// 实际执行训练的函数实现
bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift)
{
    cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();   // 创建 SVMSGD 类的实例


    cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses); // 创建训练数据
    svmsgd->train(trainData); // 训练 SVMSGD 模型


    if (svmsgd->isTrained())  // 如果模型训练成功
    {
        weights = svmsgd->getWeights();  // 获取模型权重
        shift = svmsgd->getShift();      // 获取模型偏移


        return true;                     // 返回训练成功
    }
    return false;                        // 如果训练失败,返回 false
}


// 初始化边界线段的函数实现
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
{
    std::pair<Point,Point> currentSegment;


    currentSegment.first = Point(width, 0);    // 右边界线段
    currentSegment.second = Point(width, height);
    segments.push_back(currentSegment);


    currentSegment.first = Point(0, height);   // 底边界线段
    currentSegment.second = Point(width, height);
    segments.push_back(currentSegment);


    currentSegment.first = Point(0, 0);        // 顶边界线段
    currentSegment.second = Point(width, 0);
    segments.push_back(currentSegment);


    currentSegment.first = Point(0, 0);        // 左边界线段
    currentSegment.second = Point(0, height);
    segments.push_back(currentSegment);
}




// 函数findCrossPointWithBorders用于计算给定权重和偏移量下,直线与图像边界的交点
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
{
    // 初始化交点的坐标
    int x = 0;
    int y = 0;
    // 获取线段端点的横纵坐标极值
    int xMin = std::min(segment.first.x, segment.second.x);
    int xMax = std::max(segment.first.x, segment.second.x);
    int yMin = std::min(segment.first.y, segment.second.y);
    int yMax = std::max(segment.first.y, segment.second.y);


    // 权重矩阵必须是单精度浮点类型
    CV_Assert(weights.type() == CV_32FC1);
    // 检查线段水平还是垂直
    CV_Assert(xMin == xMax || yMin == yMax);


    // 如果是垂直线段并且权重矩阵的第二个元素不为0
    if (xMin == xMax && weights.at<float>(1) != 0)
    {
        x = xMin;
        // 根据直线方程计算交点的y值
        y = static_cast<int>(std::floor( - (weights.at<float>(0) * x + shift) / weights.at<float>(1)));
        // 检测交点是否在图像的边界内
        if (y >= yMin && y <= yMax)
        {
            crossPoint.x = x;
            crossPoint.y = y;
            return true;
        }
    }
    // 如果是水平线段并且权重矩阵的第一个元素不为0
    else if (yMin == yMax && weights.at<float>(0) != 0)
    {
        y = yMin;
        // 根据直线方程计算交点的x值
        x = static_cast<int>(std::floor( - (weights.at<float>(1) * y + shift) / weights.at<float>(0)));
        // 检测交点是否在图像的边界内
        if (x >= xMin && x <= xMax)
        {
            crossPoint.x = x;
            crossPoint.y = y;
            return true;
        }
    }
    return false;
}


// 函数findPointsForLine用于计算用于绘制直线的两个点
bool findPointsForLine(const Mat &weights, float shift, Point points[2], int width, int height)
{
    // 如果权重矩阵为空,则返回失败
    if (weights.empty())
    {
        return false;
    }


    // 记录已找到的有效交点数量
    int foundPointsCount = 0;
    // 用于存储图像边框的4条线段(轮廓线)
    std::vector<std::pair<Point,Point> > segments;
    // 初始化线段(轮廓线)
    fillSegments(segments, width, height);


    // 遍历所有的边框线
    for (uint i = 0; i < segments.size(); i++)
    {
        // 如果找到与边框的交点
        if (findCrossPointWithBorders(weights, shift, segments[i], points[foundPointsCount]))
            foundPointsCount++; //增加有效交点的数量
        // 如果已找到了两个有效交点,则可以构成一条线,跳出循环
        if (foundPointsCount >= 2)
            break;
    }


    return true; //成功找到两个点
}
// 重新绘制数据集和分割线的函数实现
void redraw(Data data, const Point points[2])
{
    data.img.setTo(0);                             // 将图像设置为全黑
    Point center;
    int radius = 3;                               // 点的半径
    Scalar color;                                 // 颜色
    CV_Assert((data.samples.type() == CV_32FC1) && (data.responses.type() == CV_32FC1)); // 断言样本和响应数据类型正确
    for (int i = 0; i < data.samples.rows; i++)   // 遍历所有样本
    {
        center.x = static_cast<int>(data.samples.at<float>(i,0));
        center.y = static_cast<int>(data.samples.at<float>(i,1));
        color = (data.responses.at<float>(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128); // 根据响应值设定颜色
        circle(data.img, center, radius, color, 5); // 绘制圆形点
    }
    line(data.img, points[0], points[1],cv::Scalar(1,255,1)); // 绘制分割线


    imshow("Train svmsgd", data.img);             // 显示图像
}


// 添加训练点,重新训练 SVMSGD 算法并在图像上绘制结果的函数实现
// 函数addPointRetrainAndRedraw用于添加新的训练点,并重新训练SVMSGD算法,然后重绘图形
void addPointRetrainAndRedraw(Data &data, int x, int y, int response)
{
    // 创建一个1行2列的单精度浮点数矩阵,用于存储一个样本点
    Mat currentSample(1, 2, CV_32FC1);


    // 设定样本点的横纵坐标
    currentSample.at<float>(0,0) = (float)x;
    currentSample.at<float>(0,1) = (float)y;
    // 将新样本点加入到样本集中
    data.samples.push_back(currentSample);
    // 将样本点的响应(类别)加入到响应集中
    data.responses.push_back(static_cast<float>(response));


    // 创建权重矩阵和偏移量
    Mat weights(1, 2, CV_32FC1);
    float shift = 0;


    // 如果训练成功
    if (doTrain(data.samples, data.responses, weights, shift))
    {
        // 创建Points数组用于存储线的两个点
        Point points[2];
        // 找到用于绘制直线的两个点
        findPointsForLine(weights, shift, points, data.img.cols, data.img.rows);
        // 使用找到的两个点重新绘制图形
        redraw(data, points);
    }
}


// 鼠标回调函数,用于在图像上添加正负样本点并重新训练模型
static void onMouse( int event, int x, int y, int, void* pData)
{
    Data &data = *(Data*)pData;                  // 从pData转换获取Data结构体引用


    switch( event )                              // 根据事件类型
    {
    case EVENT_LBUTTONUP:                        // 左键松开事件
        addPointRetrainAndRedraw(data, x, y, 1); // 添加正样本点并重新训练绘制
        break;


    case EVENT_RBUTTONDOWN:                      // 右键按下事件
        addPointRetrainAndRedraw(data, x, y, -1);// 添加负样本点并重新训练绘制
        break;
    }
}


// 主函数
int main()
{
    Data data;                                   // 创建Data结构体实例


    setMouseCallback( "Train svmsgd", onMouse, &data ); // 设置鼠标回调函数
    waitKey();                                   // 等待按键


    return 0;                                    // 程序结束
}

该段代码是一个关于OpenCV和机器学习算法SVMSGD(支持向量机随机梯度下降)的简单示例,用于创建一个可交互的界面,在上面添加样本点,进行实时的线性分类器训练,并且通过绘制决策边界来显示分类结果。通过鼠标左键添加正样本,右键添加负样本。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值