【opencv450-samples】train_svmsgd.cpp

46 篇文章 4 订阅

 

与SVM不同,SVMSGD不需要设置核函数。

【参数】默认值见下述代码

模型类型:SGD、ASGD(推荐)。随机梯度下降、平均随机梯度下降。
边界类型:HARD_MARGIN、SOFT_MARGIN(推荐),前者用于线性可分,后者用于非线性可分
边界规范化 lambda:推荐设为0.0001(对于SGD),0.00001(对于ASGD)。越小,异类被抛弃的越少。
步长 gamma_0
步长降低力度 c:推荐设置为1(对于SGD),0.75(对于ASGD)
终止条件:TermCriteria::COUNT、TermCriteria::EPS、TermCriteria::COUNT + TermCriteria::EPS

参数设置函数:

setSvmsgdType()
setMarginType()
setMarginRegularization()
setInitialStepSize()
setStepDecreasingPower()

【使用方式】

cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();//创建对象
svmsgd->train(trainData);//训练
svmsgd->save("MySvmsgd.xml");//保存模型
svmsgd->load("MySvmsgd.xml");//加载模型
svmsgd->predict(samples, responses);//预测,结果保存到responses标签中

#include "opencv2/core.hpp"
#include "opencv2/video/tracking.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/ml.hpp"

using namespace cv;
using namespace cv::ml;

//https://www.cnblogs.com/xixixing/p/12430202.html
struct Data
{
    Mat img;
    Mat samples;          //一组训练样本。 包含图像上的点Set of train samples. Contains points on image
    Mat responses;        //训练样本的标签 Set of responses for train samples

    Data() //显示图像
    {
        const int WIDTH = 841;
        const int HEIGHT = 594;
        img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);
        imshow("Train svmsgd", img);
    }
};

//Train with SVMSGD algorithm
//(samples, responses) is a train set
//weights is a required vector for decision function of SVMSGD algorithm
//用SVMSGD算法训练
//(samples,responses) 是一个训练集
//weights 是 SVMSGD 算法决策函数所需的向量
bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift);

//function finds two points for drawing line (wx = 0)
//函数找到绘制线的两个点(wx = 0)
bool findPointsForLine(const Mat &weights, float shift, Point points[], int width, int height);

// function finds cross point of line (wx = 0) and segment ( (y = HEIGHT, 0 <= x <= WIDTH) or (x = WIDTH, 0 <= y <= HEIGHT) )
// 函数找到线 (wx = 0) 和线段 ( (y = HEIGHT, 0 <= x <= WIDTH) 或 (x = WIDTH, 0 <= y <= HEIGHT) ) 的交叉点
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);

//segments' initialization ( (y = HEIGHT, 0 <= x <= WIDTH) and (x = WIDTH, 0 <= y <= HEIGHT) )
//线段的初始化 ( (y = HEIGHT, 0 <= x <= WIDTH) 和 (x = WIDTH, 0 <= y <= HEIGHT) )
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height);

//redraw points' set and line (wx = 0)
//重绘点的集合和线(wx = 0)
void redraw(Data data, const Point points[2]);

//add point in train set, train SVMSGD algorithm and draw results on image
//在训练集中添加点,训练SVMSGD算法并在图像上绘制结果
void addPointRetrainAndRedraw(Data &data, int x, int y, int response);

//训练  得到参数
bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift)
{
    cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
    //*设置参数,以下全是默认参数
    //svmsgd->setSvmsgdType(SVMSGD::ASGD); //模型类型
    //svmsgd->setMarginType(SVMSGD::SOFT_MARGIN); //边界类型
    //svmsgd->setMarginRegularization(0.00001); //边界规范化
    //svmsgd->setInitialStepSize(0.05);//步长
    //svmsgd->setStepDecreasingPower(0.75); //步长减弱力度
    //svmsgd->setTermCriteria(TermCriteria(TermCriteria::COUNT,1000,1e-3));//终止条件,1000次迭代,0.001每次迭代的精度
    cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);//构造训练数据
    svmsgd->train( trainData );

    if (svmsgd->isTrained())
    {
        weights = svmsgd->getWeights();
        shift = svmsgd->getShift();
        //*保存模型
        svmsgd->save("svmsgd.xml"); //保存训练好的模型
        return true;
    }
    return false;
}
//找出边界四条直线
void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
{
    std::pair<Point,Point> currentSegment;//当前线段

    currentSegment.first = Point(width, 0);//右上角点
    currentSegment.second = Point(width, height);//右下角点
    segments.push_back(currentSegment);

    currentSegment.first = Point(0, height);//左下角点
    currentSegment.second = Point(width, height);//右下角点
    segments.push_back(currentSegment);

    currentSegment.first = Point(0, 0);//左上角点
    currentSegment.second = Point(width, 0);//右上角点
    segments.push_back(currentSegment);

    currentSegment.first = Point(0, 0);
    currentSegment.second = Point(0, height);
    segments.push_back(currentSegment);
}

//找到与边界框交点
bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
{
    int x = 0;
    int y = 0;
    int xMin = std::min(segment.first.x, segment.second.x);
    int xMax = std::max(segment.first.x, segment.second.x);
    int yMin = std::min(segment.first.y, segment.second.y);
    int yMax = std::max(segment.first.y, segment.second.y);

    CV_Assert(weights.type() == CV_32FC1);
    CV_Assert(xMin == xMax || yMin == yMax);//断言:线段为垂直或者水平
    //一条垂直线      边框的左侧和右侧线    
    if (xMin == xMax && weights.at<float>(1) != 0) //AX+BY+C=0  B!=0
    {
        x = xMin;
        y = static_cast<int>(std::floor( - (weights.at<float>(0) * x + shift) / weights.at<float>(1)));
        if (y >= yMin && y <= yMax)
        { //直线与边框左右侧线条的交点
            crossPoint.x = x;
            crossPoint.y = y;
            return true;
        }
    }
    //一条水平线  边框的上侧和下侧线    
    else if (yMin == yMax && weights.at<float>(0) != 0)//A!=0
    {
        y = yMin;
        x = static_cast<int>(std::floor( - (weights.at<float>(1) * y + shift) / weights.at<float>(0)));
        if (x >= xMin && x <= xMax)
        {  //直线与边框上下端线条的交点
            crossPoint.x = x;
            crossPoint.y = y;
            return true;
        }
    }
    return false;
}
//根据直线找到与边界框的交点  2个
bool findPointsForLine(const Mat &weights, float shift, Point points[2], int width, int height)
{
    if (weights.empty())//直线权重参数非空
    {
        return false;
    }

    int foundPointsCount = 0;//找到的点数
    std::vector<std::pair<Point,Point> > segments;//点对集合    线段集合
    fillSegments(segments, width, height);//找到边界框

    for (uint i = 0; i < segments.size(); i++)
    {   //找到直线与边界框的交点
        if (findCrossPointWithBorders(weights, shift, segments[i], points[foundPointsCount]))
            foundPointsCount++;//直线与边界框交点数
        if (foundPointsCount >= 2)
            break;
    }

    return true;
}
//绘制直线
void redraw(Data data, const Point points[2])
{
    data.img.setTo(0);//黑色背景
    Point center;//样本中心点
    int radius = 3;//半径3
    Scalar color;
    CV_Assert((data.samples.type() == CV_32FC1) && (data.responses.type() == CV_32FC1));//断言:数据样本类型
    for (int i = 0; i < data.samples.rows; i++)//遍历样本
    {
        center.x = static_cast<int>(data.samples.at<float>(i,0));
        center.y = static_cast<int>(data.samples.at<float>(i,1));
        color = (data.responses.at<float>(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128);
        circle(data.img, center, radius, color, 5);//绘制样本点
    }
    line(data.img, points[0], points[1],cv::Scalar(1,255,1));//绘制直线

    imshow("Train svmsgd", data.img);//显示图像
}
//添加点  标签response:1 / -1
void addPointRetrainAndRedraw(Data &data, int x, int y, int response)
{
    Mat currentSample(1, 2, CV_32FC1);//临时点坐标 x,y   float 

    currentSample.at<float>(0,0) = (float)x;
    currentSample.at<float>(0,1) = (float)y;
    data.samples.push_back(currentSample);//添加到数据样本中
    data.responses.push_back(static_cast<float>(response));//添加到数据标签中

    Mat weights(1, 2, CV_32FC1);//权重系数A,B     超平面: AX+BY+C=0
    float shift = 0;//C
    //训练,得到超平面即直线参数
    if (doTrain(data.samples, data.responses, weights, shift))
    {
        Point points[2];
        findPointsForLine(weights, shift, points, data.img.cols, data.img.rows);//找到直线与边界框的交点

        redraw(data, points);//绘制直线和样本点
    }
}

//鼠标回调
static void onMouse( int event, int x, int y, int, void* pData)
{
    Data &data = *(Data*)pData;//数据指针

    switch( event )
    {
    case EVENT_LBUTTONUP:
        addPointRetrainAndRedraw(data, x, y, 1);//左键 添加点标签1
        break;

    case EVENT_RBUTTONDOWN:
        addPointRetrainAndRedraw(data, x, y, -1);//右键 添加点标签-1
        break;
    }

}

int main()
{
    Data data;

    setMouseCallback( "Train svmsgd", onMouse, &data );
    waitKey();

    return 0;
}

 svmsgd.xml

参考:

基于SGD、ASGD算法的SVM分类器(OpenCV案例源码train_svmsgd.cpp解读) - 夕西行 - 博客园

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值