Source Code and Analysis of SVM Character Classification Based on OpenCV 3.x

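SVM is fundamentally a binary classifier, but it can also be extended to multi-class classification. OpenCV's SVM already supports multi-class classification, so we do not need to write that code by hand. We can apply SVM to classifying digits and characters: given an image of a character, which of the digits 0-9 does it show? The SVM assigns the image to one of the trained classes. This post provides the complete test code and the data download.
Because this involves machine learning and the number of samples varies, the time required differs from one machine to another. The hardware configuration of my test machine is listed below.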

CPU i7-8750H
RAM 8GB
Disk 128G SSD
GPU GTX 1050Ti

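During training, CPU usage on my machine was about 20% and training took about 30 s. On an i5 machine it may take noticeably longer, so please be patient.
The dataset comes from a CSDN download; thanks to its author for sharing the data, which made my own testing possible.
To keep the test simple, we only classify the digits 0-9 and the letters G, T, W, X. In other words, any given image is assigned to one of these 14 classes.
Because the number of images differs from class to class in this dataset, I take 200 images from each class as the test set and use all the remaining images as training samples.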
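The hold-out split described above can be sketched roughly as follows. This is only an illustration, not the procedure that was actually used to split the dataset: the `SplitResult` type, the `splitOneClass` function, and the fixed seed are hypothetical names and choices introduced here, and the sketch assumes the file names of one class have already been collected in a vector.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <random>
#include <string>
#include <vector>

// Hypothetical helper type: the two subsets of one class directory.
struct SplitResult {
    std::vector<std::string> train;
    std::vector<std::string> test;
};

// Hypothetical helper: shuffle the file list with a fixed seed, hold out the
// first `testCount` entries for testing, and keep the rest for training.
SplitResult splitOneClass(std::vector<std::string> files,
                          std::size_t testCount,
                          unsigned seed = 42)
{
    // The class must contain more images than we hold out,
    // otherwise nothing would be left for training.
    assert(files.size() > testCount);

    std::mt19937 rng(seed);
    std::shuffle(files.begin(), files.end(), rng);

    const auto cut = files.begin() + static_cast<std::ptrdiff_t>(testCount);

    SplitResult result;
    result.test.assign(files.begin(), cut);
    result.train.assign(cut, files.end());
    return result;
}
```

Calling `splitOneClass(files, 200)` once per class directory would mirror the 200-image test set described above; a fixed seed keeps the split reproducible between runs.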

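Let's look at the results first. The program prints the number of training samples for each class during the training stage, and then the accuracy on the test set. As the results show, it is hard to reach 100% accuracy. Later, we will provide a download link for the complete project code.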

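// main.cpp

#include <cstdlib>      // system
#include <iostream>
#include "opencv2/opencv.hpp"

// Defined in test_cv3_char_recog.cpp
void test_CV3CharRecog();

int main()
{
    // Time the whole train + test run with OpenCV's tick counter
    double start = static_cast<double>(cv::getTickCount());
    test_CV3CharRecog();
    double end = static_cast<double>(cv::getTickCount());
    double time = (end - start) / cv::getTickFrequency();   // elapsed time in seconds
    std::cout << "time = " << time << std::endl;

    system("pause");    // Windows only: keeps the console window open
    return 0;
}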

Above is the main function; next comes the main test driver that it calls.

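//test_cv3_char_recog.cpp

/**
* Training stage:
* [1] Batch-load the images.
* This demo ships its own batch image loader; an interface such as Qt can be used instead.
* A directory must hold images of exactly one class (e.g. the digit 2) and nothing from other classes.
* The loaded data is a vector<Mat>.
*
* [2] Add the images and the label of each class to the training data.
* void CV3SVMOperate::addTrainSample(std::vector<cv::Mat>& imgs, int labelValue);
* imgs is the data produced in [1], and labelValue is a static member of CharsLabelValue,
* e.g. CharsLabelValue::NUM_ZERO. Do not pass raw integer literals instead.
*
* [3] Set the SVM and HOG parameters; this is not needed if the constructor already specifies them.
*
* [4] Train.
* void CV3SVMOperate::train();
* void CV3SVMOperate::train(const std::string& savePath);
* The first overload does not save a model file and is meant for testing only.
* The second overload saves the model file.
*
*
*
* Prediction stage:
* [1] Specify the model path through the constructor.
* CV3SVMOperate::CV3SVMOperate(const std::string& modelPath);
*
* [2] Set the HOG parameters.
* void setHOG(cv::Size winSize, cv::Size blockSize, cv::Size blockStride, cv::Size cellSize, int nbins);
*
* [3] Predict.
* int CV3SVMOperate::predict(cv::Mat& img);
* std::string CV3SVMOperate::predictCharName(cv::Mat& img);
* The first function returns only the predicted label value.
* The second function returns the character that the label value maps to.
*
*/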
#include <iostream>
#include "opencv2/opencv.hpp"
#include "chars_info.h"
#include "opencv_3_svm_operate.h"
#include "opencv2410_inputoutput.h"
//#include "../ReadFilePaths/opencv2410_inputoutput.h"



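/*
* Load every jpg image under dir into imgsMat (resized to a fixed size, otherwise
* unprocessed) and return the number of images loaded.
* dir: path of the directory
* imgsMat: receives all loaded images
* return: number of images loaded from this directory
*/
static int LoadImgs(const std::string& dir, std::vector<cv::Mat>& imgsMat)
{
    imgsMat.clear();

    std::vector<std::string> filePaths = Directory::GetListFilesR(dir, "*.jpg", true);

    for (const std::string& str : filePaths)
    {
        cv::Mat img = cv::imread(str, cv::IMREAD_COLOR);
        if (img.empty())
        {
            continue;   // skip unreadable files; cv::resize would throw on an empty Mat
        }
        cv::Mat tmp;
        // Note: the target size is 16x25, while setHOG below uses a 16x24 window
        cv::resize(img, tmp, cv::Size(16, 25));
        imgsMat.push_back(tmp);
    }
    return static_cast<int>(imgsMat.size());
}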


/*
* Add all images under one path, together with their label value, to the training data.
* cv3SVM: CV3SVMOperate object, which holds the training data as member variables
* path: directory path of the images of one class
* labelValue: label value of that class. It must be a static member of CharsLabelValue.
*           A class that is not listed there can be added by hand, but label values must not repeat.
*/
static void AddOneDir(CV3SVMOperate& cv3SVM, std::string path, int labelValue)
{
    std::vector<cv::Mat> imgsMat;
    LoadImgs(path, imgsMat);
    cv3SVM.addTrainSample(imgsMat, labelValue);
    std::cout << "Class: " << CharsLabelValue::LabelValueToCharName(labelValue);
    std::cout << "\tSamples: " << imgsMat.size() << std::endl;
}


/*
* Predict every image of one test class, compare each prediction with the
* expected character name, and print the statistics.
* imgs: all images in the directory of one class
* cv3SVM: CV3SVMOperate object
* charName: the correct character name
*/
static void CompareResult(std::vector<cv::Mat>& imgs, CV3SVMOperate& cv3SVM, std::string charName)
{
    int cnt = 0;
    for (size_t i = 0; i < imgs.size(); i++)
    {
        cv::Mat img = imgs.at(i);
        std::string result = cv3SVM.predictCharName(img);
        if (result == charName)
        {
            cnt++;
        }
    }

    // Guard against an empty directory to avoid dividing by zero
    double ratio = imgs.empty() ? 0.0 : double(cnt) / double(imgs.size());

    std::cout << "Character: ";
    std::cout << charName << " ";
    std::cout << "Test samples: " << imgs.size();
    std::cout << "\tCorrect: " << cnt;
    std::cout << "\tAccuracy: " << ratio << std::endl;
}


static void test_OneDr(std::string& imgDir, CV3SVMOperate& cv3SVM, std::string result)
{
    std::vector<cv::Mat> imgsMat;
    LoadImgs(imgDir, imgsMat);
    CompareResult(imgsMat, cv3SVM, result);
}


static void test_trainHOG()
{
    //std::string prefix = "D:\\workspace\\vs2015_workspace\\AlgsBaseOpenCV\\Resources\\CharRecDemoImage\\trainData\\";
    std::string prefix = "../Resources/CharRecDemoImage/trainData/";
    std::string modelPath = "../Resources/b.xml";

    std::string num0 = prefix + "0";
    std::string num1 = prefix + "1";
    std::string num2 = prefix + "2";
    std::string num3 = prefix + "3";
    std::string num4 = prefix + "4";
    std::string num5 = prefix + "5";
    std::string num6 = prefix + "6";
    std::string num7 = prefix + "7";
    std::string num8 = prefix + "8";
    std::string num9 = prefix + "9";

    std::string alphaCapG = prefix + "G";
    std::string alphaCapT = prefix + "T";
    std::string alphaCapW = prefix + "W";
    std::string alphaCapX = prefix + "X";

    CV3SVMOperate cv3SVM;
    cv3SVM.setHOG(cv::Size(16, 24)/*winSize*/, cv::Size(8, 8)/*blockSize*/, cv::Size(2, 4)/*blockStride*/, cv::Size(4, 4)/*cellSize*/, 9/*nbins*/);
    cv3SVM.setSVM();

    // Add the training data
    std::cout << "Adding training data: start" << std::endl;
    AddOneDir(cv3SVM, num0, CharsLabelValue::NUM_ZERO);
    AddOneDir(cv3SVM, num1, CharsLabelValue::NUM_ONE);
    AddOneDir(cv3SVM, num2, CharsLabelValue::NUM_TWO);
    AddOneDir(cv3SVM, num3, CharsLabelValue::NUM_THREE);
    AddOneDir(cv3SVM, num4, CharsLabelValue::NUM_FOUR);
    AddOneDir(cv3SVM, num5, CharsLabelValue::NUM_FIVE);
    AddOneDir(cv3SVM, num6, CharsLabelValue::NUM_SIX);
    AddOneDir(cv3SVM, num7, CharsLabelValue::NUM_SEVEN);
    AddOneDir(cv3SVM, num8, CharsLabelValue::NUM_EIGHT);
    AddOneDir(cv3SVM, num9, CharsLabelValue::NUM_NINE);

    AddOneDir(cv3SVM, alphaCapG, CharsLabelValue::CHAR_CAPTITAL_G);
    AddOneDir(cv3SVM, alphaCapT, CharsLabelValue::CHAR_CAPTITAL_T);
    AddOneDir(cv3SVM, alphaCapW, CharsLabelValue::CHAR_CAPTITAL_W);
    AddOneDir(cv3SVM, alphaCapX, CharsLabelValue::CHAR_CAPTITAL_X);
    std::cout << "Adding training data: done" << std::endl;



    std::cout << "Training started" << std::endl;

    cv3SVM.train(modelPath);
    std::cout << "Training finished" << std::endl;

    cv::waitKey(1);

}


static void test_Predict()
{
    std::string modelPath = "../Resources/b.xml";
    //std::string prefix = "D:\\workspace\\vs2015_workspace\\AlgsBaseOpenCV\\Resources\\CharRecDemoImage\\testData\\";
    std::string prefix = "../Resources/CharRecDemoImage/testData/";

    CV3SVMOperate cv3SVM(modelPath);        
    cv3SVM.setHOG(cv::Size(16, 24)/*winSize*/, cv::Size(8, 8)/*blockSize*/, cv::Size(2, 4)/*blockStride*/, cv::Size(4, 4)/*cellSize*/, 9/*nbins*/);


    std::string num0 = prefix + "0";
    std::string num1 = prefix + "1";
    std::string num2 = prefix + "2";
    std::string num3 = prefix + "3";
    std::string num4 = prefix + "4";
    std::string num5 = prefix + "5";
    std::string num6 = prefix + "6";
    std::string num7 = prefix + "7";
    std::string num8 = prefix + "8";
    std::string num9 = prefix + "9";

    std::string alphaCapG = prefix + "G";
    std::string alphaCapT = prefix + "T";
    std::string alphaCapW = prefix + "W";
    std::string alphaCapX = prefix + "X";


    test_OneDr(num0, cv3SVM, "0");
    test_OneDr(num1, cv3SVM, "1");
    test_OneDr(num2, cv3SVM, "2");
    test_OneDr(num3, cv3SVM, "3");
    test_OneDr(num4, cv3SVM, "4");
    test_OneDr(num5, cv3SVM, "5");
    test_OneDr(num6, cv3SVM, "6");
    test_OneDr(num7, cv3SVM, "7");
    test_OneDr(num8, cv3SVM, "8");
    test_OneDr(num9, cv3SVM, "9");

    test_OneDr(alphaCapG, cv3SVM, "G");
    test_OneDr(alphaCapT, cv3SVM, "T");
    test_OneDr(alphaCapW, cv3SVM, "W");
    test_OneDr(alphaCapX, cv3SVM, "X");
}


void test_CV3CharRecog()
{
    // test_trainHOG and test_Predict are the static functions defined above

    std::cout << "Training stage" << std::endl;
    test_trainHOG();            // train

    std::cout << "Testing the trained model" << std::endl;
    std::cout << "---------------------------" << std::endl;
    std::cout << "Test started" << std::endl;
    test_Predict();             // test
    std::cout << "Test finished" << std::endl;

    cv::waitKey(1);
}
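CompareResult prints the accuracy of each class separately. If a single overall figure is also wanted, the per-class counts can be pooled, as in the minimal sketch below. `ClassResult` and `overallAccuracy` are hypothetical names introduced only for this illustration; they are not part of the project code.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical record: the counts that CompareResult prints for one class.
struct ClassResult {
    std::string charName;   // expected character, e.g. "0" or "G"
    std::size_t total;      // number of test images of this class
    std::size_t correct;    // number of images predicted correctly
};

// Pool the per-class counts into one overall accuracy in [0, 1].
// Returns 0.0 when there are no test images at all.
double overallAccuracy(const std::vector<ClassResult>& results)
{
    std::size_t total = 0;
    std::size_t correct = 0;
    for (const ClassResult& r : results)
    {
        assert(r.correct <= r.total);   // a class cannot have more hits than images
        total += r.total;
        correct += r.correct;
    }
    if (total == 0)
    {
        return 0.0;
    }
    return static_cast<double>(correct) / static_cast<double>(total);
}
```

Pooling the counts weights every image equally. With 200 test images in each class, as in this post, that is the same as averaging the per-class accuracies.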

Next are the main header files:

// chars_info.h

#ifndef CHARS_INFO_H
#define CHARS_INFO_H

#include <string>

class CharsLabelValue
{
public:
    static std::string LabelValueToCharName(int v);
public:
    const static int NUM_ZERO;
    const static int NUM_ONE;
    const static int NUM_TWO;
    const static int NUM_THREE;
    const static int NUM_FOUR;
    const static int NUM_FIVE;
    const static int NUM_SIX;
    const static int NUM_SEVEN;
    const static int NUM_EIGHT;
    const static int NUM_NINE;

    const static int CHAR_CAPTITAL_A;   // uppercase letters
    const static int CHAR_CAPTITAL_B;
    const static int CHAR_CAPTITAL_C;
    const static int CHAR_CAPTITAL_D;
    const static int CHAR_CAPTITAL_E;
    const static int CHAR_CAPTITAL_F;
    const static int CHAR_CAPTITAL_G;
    const static int CHAR_CAPTITAL_H;
    const static int CHAR_CAPTITAL_I;
    const static int CHAR_CAPTITAL_J;
    const static int CHAR_CAPTITAL_K;
    const static int CHAR_CAPTITAL_L;
    const static int CHAR_CAPTITAL_M;
    const static int CHAR_CAPTITAL_N;
    const static int CHAR_CAPTITAL_O;
    const static int CHAR_CAPTITAL_P;
    const static int CHAR_CAPTITAL_Q;
    const static int CHAR_CAPTITAL_R;
    const static int CHAR_CAPTITAL_S;
    const static int CHAR_CAPTITAL_T;
    const static int CHAR_CAPTITAL_U;
    const static int CHAR_CAPTITAL_V;
    const static int CHAR_CAPTITAL_W;
    const static int CHAR_CAPTITAL_X;
    const static int CHAR_CAPTITAL_Y;
    const static int CHAR_CAPTITAL_Z;


    // lowercase letters
    const static int CHAR_LOWERCASE_A;
    const static int CHAR_LOWERCASE_B;
    const static int CHAR_LOWERCASE_C;
    const static int CHAR_LOWERCASE_D;
    const static int CHAR_LOWERCASE_E;
    const static int CHAR_LOWERCASE_F;
    const static int CHAR_LOWERCASE_G;
    const static int CHAR_LOWERCASE_H;
    const static int CHAR_LOWERCASE_I;
    const static int CHAR_LOWERCASE_J;
    const static int CHAR_LOWERCASE_K;
    const static int CHAR_LOWERCASE_L;
    const static int CHAR_LOWERCASE_M;
    const static int CHAR_LOWERCASE_N;
    const static int CHAR_LOWERCASE_O;
    const static int CHAR_LOWERCASE_P;
    const static int CHAR_LOWERCASE_Q;
    const static int CHAR_LOWERCASE_R;
    const static int CHAR_LOWERCASE_S;
    const static int CHAR_LOWERCASE_T;
    const static int CHAR_LOWERCASE_U;
    const static int CHAR_LOWERCASE_V;
    const static int CHAR_LOWERCASE_W;
    const static int CHAR_LOWERCASE_X;
    const static int CHAR_LOWERCASE_Y;
    const static int CHAR_LOWERCASE_Z;

private:
    CharsLabelValue();
};

#endif // !CHARS_INFO_H

// opencv_3_svm_operate.h

#ifndef OPENCV_3_SVM_OPERATE_H
#define OPENCV_3_SVM_OPERATE_H

#include <string>
#include <vector>
//#include "opencv2/opencv.hpp"
#include "opencv2/core/core.hpp"
#include "opencv2/ml/ml.hpp"
#include "opencv2/objdetect/objdetect.hpp"

class CV3SVMOperate
{
private:
    cv::Ptr<cv::ml::SVM> m_svm;         // SVM model
    cv::HOGDescriptor m_hog;            // HOG descriptor
    std::vector<int> m_allLabels;       // labels of the training samples
    std::vector<cv::Mat> m_allGrads;    // HOG descriptors of all training images

public:
    CV3SVMOperate();

    /*
    * Constructor
    * modelPath: path of the trained model file
    */
    CV3SVMOperate(const std::string& modelPath);

    /*
    * All parameters must satisfy the following conditions
    * (winSize.width - blockSize.width) % blockStride.width == 0
    * (winSize.height - blockSize.height) % blockStride.height == 0
    * (blockSize.width % cellSize.width) == 0
    * (blockSize.height % cellSize.height) == 0
    */
    void setHOG(cv::Size winSize, cv::Size blockSize, cv::Size blockStride, cv::Size cellSize, int nbins);

    /*
    * Set the SVM properties
    */
    void setSVM();

    /*
    * Train;
    * does not save a model file
    */
    void train();

    /*
    * Train;
    * saves the model file
    * savePath: path of the model file written after SVM training
    */
    void train(const std::string& savePath);

    /*
    * Classify the input image with the model; the class is returned as an int
    * img: input image
    * return: the class the image belongs to, as a number only
    */
    int predict(cv::Mat& img);

    /*
    * Return the name of the predicted character
    * img: the original input image
    * return: the character, as a string
    */
    std::string predictCharName(cv::Mat& img);

    /*
    * Add all images of one class, together with their label, to the training data
    * imgs: all images, as Mat
    * labelValue: label value of imgs; see the static variables of the CharsLabelValue class
    */
    void addTrainSample(std::vector<cv::Mat>& imgs, int labelValue);
};

#endif // !OPENCV_3_SVM_OPERATE_H
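The four constraints listed in the setHOG comment, and the descriptor length they imply, can be checked with plain integer arithmetic. The sketch below is an illustration under stated assumptions: `Size2`, `hogParamsValid`, and `hogDescriptorLength` are hypothetical stand-ins introduced here (they are not OpenCV types or calls), all sizes are assumed to be positive, and the length is computed as blocks per window times cells per block times bins.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for cv::Size so the sketch needs no OpenCV headers.
struct Size2 {
    int width;
    int height;
};

// True when the four divisibility constraints from the setHOG comment hold.
bool hogParamsValid(Size2 win, Size2 block, Size2 stride, Size2 cell)
{
    return (win.width - block.width) % stride.width == 0 &&
           (win.height - block.height) % stride.height == 0 &&
           block.width % cell.width == 0 &&
           block.height % cell.height == 0;
}

// Number of float values in the HOG descriptor of one window.
std::size_t hogDescriptorLength(Size2 win, Size2 block, Size2 stride, Size2 cell, int nbins)
{
    assert(hogParamsValid(win, block, stride, cell));

    const int blocksX = (win.width - block.width) / stride.width + 1;
    const int blocksY = (win.height - block.height) / stride.height + 1;
    const int cellsPerBlock = (block.width / cell.width) * (block.height / cell.height);

    return static_cast<std::size_t>(blocksX) * blocksY * cellsPerBlock * nbins;
}
```

With the parameters used in this post (window 16x24, block 8x8, stride 2x4, cell 4x4, 9 bins) this gives 5 x 5 blocks of 4 cells each, that is 5 * 5 * 4 * 9 = 900 values per image, which is the number of columns of the training matrix.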

The implementations of the header files above:

// chars_info.cpp
#include "chars_info.h"


// Digit label values
const int CharsLabelValue::NUM_ZERO  = 0;
const int CharsLabelValue::NUM_ONE = 1;
const int CharsLabelValue::NUM_TWO = 2;
const int CharsLabelValue::NUM_THREE = 3;
const int CharsLabelValue::NUM_FOUR = 4;
const int CharsLabelValue::NUM_FIVE = 5;
const int CharsLabelValue::NUM_SIX = 6;
const int CharsLabelValue::NUM_SEVEN = 7;
const int CharsLabelValue::NUM_EIGHT = 8;
const int CharsLabelValue::NUM_NINE = 9;


// Uppercase letter label values
const int CharsLabelValue::CHAR_CAPTITAL_A = 10;
const int CharsLabelValue::CHAR_CAPTITAL_B = 11;
const int CharsLabelValue::CHAR_CAPTITAL_C = 12;
const int CharsLabelValue::CHAR_CAPTITAL_D = 13;
const int CharsLabelValue::CHAR_CAPTITAL_E = 14;
const int CharsLabelValue::CHAR_CAPTITAL_F = 15;
const int CharsLabelValue::CHAR_CAPTITAL_G = 16;
const int CharsLabelValue::CHAR_CAPTITAL_H = 17;
const int CharsLabelValue::CHAR_CAPTITAL_I = 18;
const int CharsLabelValue::CHAR_CAPTITAL_J = 19;
const int CharsLabelValue::CHAR_CAPTITAL_K = 20;
const int CharsLabelValue::CHAR_CAPTITAL_L = 21;
const int CharsLabelValue::CHAR_CAPTITAL_M = 22;
const int CharsLabelValue::CHAR_CAPTITAL_N = 23;
const int CharsLabelValue::CHAR_CAPTITAL_O = 24;
const int CharsLabelValue::CHAR_CAPTITAL_P = 25;
const int CharsLabelValue::CHAR_CAPTITAL_Q = 26;
const int CharsLabelValue::CHAR_CAPTITAL_R = 27;
const int CharsLabelValue::CHAR_CAPTITAL_S = 28;
const int CharsLabelValue::CHAR_CAPTITAL_T = 29;
const int CharsLabelValue::CHAR_CAPTITAL_U = 30;
const int CharsLabelValue::CHAR_CAPTITAL_V = 31;
const int CharsLabelValue::CHAR_CAPTITAL_W = 32;
const int CharsLabelValue::CHAR_CAPTITAL_X = 33;
const int CharsLabelValue::CHAR_CAPTITAL_Y = 34;
const int CharsLabelValue::CHAR_CAPTITAL_Z = 35;

// Lowercase letter label values
const int CharsLabelValue::CHAR_LOWERCASE_A = 36;
const int CharsLabelValue::CHAR_LOWERCASE_B = 37;
const int CharsLabelValue::CHAR_LOWERCASE_C = 38;
const int CharsLabelValue::CHAR_LOWERCASE_D = 39;
const int CharsLabelValue::CHAR_LOWERCASE_E = 40;
const int CharsLabelValue::CHAR_LOWERCASE_F = 41;
const int CharsLabelValue::CHAR_LOWERCASE_G = 42;
const int CharsLabelValue::CHAR_LOWERCASE_H = 43;
const int CharsLabelValue::CHAR_LOWERCASE_I = 44;
const int CharsLabelValue::CHAR_LOWERCASE_J = 45;
const int CharsLabelValue::CHAR_LOWERCASE_K = 46;
const int CharsLabelValue::CHAR_LOWERCASE_L = 47;
const int CharsLabelValue::CHAR_LOWERCASE_M = 48;
const int CharsLabelValue::CHAR_LOWERCASE_N = 49;
const int CharsLabelValue::CHAR_LOWERCASE_O = 50;
const int CharsLabelValue::CHAR_LOWERCASE_P = 51;
const int CharsLabelValue::CHAR_LOWERCASE_Q = 52;
const int CharsLabelValue::CHAR_LOWERCASE_R = 53;
const int CharsLabelValue::CHAR_LOWERCASE_S = 54;
const int CharsLabelValue::CHAR_LOWERCASE_T = 55;
const int CharsLabelValue::CHAR_LOWERCASE_U = 56;
const int CharsLabelValue::CHAR_LOWERCASE_V = 57;
const int CharsLabelValue::CHAR_LOWERCASE_W = 58;
const int CharsLabelValue::CHAR_LOWERCASE_X = 59;
const int CharsLabelValue::CHAR_LOWERCASE_Y = 60;
const int CharsLabelValue::CHAR_LOWERCASE_Z = 61;



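/*
* Return the character name that corresponds to a predicted value
* v: the label value
* return: the character name
*/
std::string CharsLabelValue::LabelValueToCharName(int v)
{
    /*
    * For characters that never need to be recognized, the checks can be commented out,
    * or the checks can be reordered so that the needed characters come first, to avoid a speed penalty
    */

    // Which digit is it?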
    if (v == CharsLabelValue::NUM_ZERO)     return "0";
    if (v == CharsLabelValue::NUM_ONE)      return "1";
    if (v == CharsLabelValue::NUM_TWO)      return "2";
    if (v == CharsLabelValue::NUM_THREE)    return "3";
    if (v == CharsLabelValue::NUM_FOUR)     return "4";
    if (v == CharsLabelValue::NUM_FIVE)     return "5";
    if (v == CharsLabelValue::NUM_SIX)      return "6";
    if (v == CharsLabelValue::NUM_SEVEN)    return "7";
    if (v == CharsLabelValue::NUM_EIGHT)    return "8";
    if (v == CharsLabelValue::NUM_NINE)     return "9";

    // Which uppercase letter is it?
    if (v == CharsLabelValue::CHAR_CAPTITAL_A) return "A";
    if (v == CharsLabelValue::CHAR_CAPTITAL_B) return "B";
    if (v == CharsLabelValue::CHAR_CAPTITAL_C) return "C";
    if (v == CharsLabelValue::CHAR_CAPTITAL_D) return "D";
    if (v == CharsLabelValue::CHAR_CAPTITAL_E) return "E";
    if (v == CharsLabelValue::CHAR_CAPTITAL_F) return "F";
    if (v == CharsLabelValue::CHAR_CAPTITAL_G) return "G";
    if (v == CharsLabelValue::CHAR_CAPTITAL_H) return "H";
    if (v == CharsLabelValue::CHAR_CAPTITAL_I) return "I";
    if (v == CharsLabelValue::CHAR_CAPTITAL_J) return "J";
    if (v == CharsLabelValue::CHAR_CAPTITAL_K) return "K";
    if (v == CharsLabelValue::CHAR_CAPTITAL_L) return "L";
    if (v == CharsLabelValue::CHAR_CAPTITAL_M) return "M";
    if (v == CharsLabelValue::CHAR_CAPTITAL_N) return "N";
    if (v == CharsLabelValue::CHAR_CAPTITAL_O) return "O";
    if (v == CharsLabelValue::CHAR_CAPTITAL_P) return "P";
    if (v == CharsLabelValue::CHAR_CAPTITAL_Q) return "Q";
    if (v == CharsLabelValue::CHAR_CAPTITAL_R) return "R";
    if (v == CharsLabelValue::CHAR_CAPTITAL_S) return "S";
    if (v == CharsLabelValue::CHAR_CAPTITAL_T) return "T";
    if (v == CharsLabelValue::CHAR_CAPTITAL_U) return "U";
    if (v == CharsLabelValue::CHAR_CAPTITAL_V) return "V";
    if (v == CharsLabelValue::CHAR_CAPTITAL_W) return "W";
    if (v == CharsLabelValue::CHAR_CAPTITAL_X) return "X";
    if (v == CharsLabelValue::CHAR_CAPTITAL_Y) return "Y";
    if (v == CharsLabelValue::CHAR_CAPTITAL_Z) return "Z";

    // Which lowercase letter is it?
    if (v == CharsLabelValue::CHAR_LOWERCASE_A) return "a";
    if (v == CharsLabelValue::CHAR_LOWERCASE_B) return "b";
    if (v == CharsLabelValue::CHAR_LOWERCASE_C) return "c";
    if (v == CharsLabelValue::CHAR_LOWERCASE_D) return "d";
    if (v == CharsLabelValue::CHAR_LOWERCASE_E) return "e";
    if (v == CharsLabelValue::CHAR_LOWERCASE_F) return "f";
    if (v == CharsLabelValue::CHAR_LOWERCASE_G) return "g";
    if (v == CharsLabelValue::CHAR_LOWERCASE_H) return "h";
    if (v == CharsLabelValue::CHAR_LOWERCASE_I) return "i";
    if (v == CharsLabelValue::CHAR_LOWERCASE_J) return "j";
    if (v == CharsLabelValue::CHAR_LOWERCASE_K) return "k";
    if (v == CharsLabelValue::CHAR_LOWERCASE_L) return "l";
    if (v == CharsLabelValue::CHAR_LOWERCASE_M) return "m";
    if (v == CharsLabelValue::CHAR_LOWERCASE_N) return "n";
    if (v == CharsLabelValue::CHAR_LOWERCASE_O) return "o";
    if (v == CharsLabelValue::CHAR_LOWERCASE_P) return "p";
    if (v == CharsLabelValue::CHAR_LOWERCASE_Q) return "q";
    if (v == CharsLabelValue::CHAR_LOWERCASE_R) return "r";
    if (v == CharsLabelValue::CHAR_LOWERCASE_S) return "s";
    if (v == CharsLabelValue::CHAR_LOWERCASE_T) return "t";
    if (v == CharsLabelValue::CHAR_LOWERCASE_U) return "u";
    if (v == CharsLabelValue::CHAR_LOWERCASE_V) return "v";
    if (v == CharsLabelValue::CHAR_LOWERCASE_W) return "w";
    if (v == CharsLabelValue::CHAR_LOWERCASE_X) return "x";
    if (v == CharsLabelValue::CHAR_LOWERCASE_Y) return "y";
    if (v == CharsLabelValue::CHAR_LOWERCASE_Z) return "z";


    // Anything that matches none of the conditions above returns an empty string
    return "";
}
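Because the label values above are consecutive (0-9 for the digits, 10-35 for the uppercase letters, 36-61 for the lowercase letters), the chain of comparisons could also be written as a table lookup. The sketch below is one possible alternative, not the project's implementation: `labelToCharName` is a hypothetical name, and the approach only works as long as the label values keep exactly this order.

```cpp
#include <cassert>
#include <string>

// Hypothetical alternative to LabelValueToCharName: index into a table whose
// position i holds the character for label value i (assumes labels 0..61 in
// the order digits, uppercase, lowercase, as assigned in chars_info.cpp).
std::string labelToCharName(int label)
{
    static const std::string kNames =
        "0123456789"
        "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "abcdefghijklmnopqrstuvwxyz";
    assert(kNames.size() == 62);    // 10 digits + 26 uppercase + 26 lowercase

    if (label < 0 || label >= static_cast<int>(kNames.size()))
    {
        return "";                  // unknown label, same convention as the original
    }
    return std::string(1, kNames[static_cast<std::string::size_type>(label)]);
}
```

The lookup takes constant time regardless of which character is predicted, so the ordering advice in the comment of LabelValueToCharName would no longer matter.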

// opencv_3_svm_operate.cpp

#include <algorithm>    // std::max
#include <cfloat>       // FLT_EPSILON
#include <string>
#include <vector>
#include "opencv_3_svm_operate.h"
#include "chars_info.h"


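/*
* Pack all HOG descriptors into one training matrix, one sample per row.
* gradientLst: container holding every descriptor (each a 1xN or Nx1 CV_32F Mat)
* trainData: the resulting training matrix
*/
static void ConvertMLData(std::vector<cv::Mat>& gradientLst, cv::Mat& trainData)
{
    if (gradientLst.empty())
    {
        trainData.release();    // nothing to convert; avoids reading gradientLst[0]
        return;
    }

    const int ROWS = (int)gradientLst.size();
    const int COLS = (int)std::max(gradientLst[0].cols, gradientLst[0].rows);

    cv::Mat tmp(1, COLS, CV_32FC1); // holds the transposed descriptor
    trainData = cv::Mat(ROWS, COLS, CV_32FC1);

    // Store the data row by row
    for (size_t i = 0; i < gradientLst.size(); i++)
    {
        cv::Mat m = gradientLst.at(i);
        if (m.cols == 1)
        {
            cv::transpose(m, tmp);
            tmp.copyTo(trainData.row((int)i));
        }
        else if (m.rows == 1)
        {
            m.copyTo(trainData.row((int)i));
        }
    }
}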


/*
* Return the name of the predicted character
* img: the original input image
* return: the character, as a string
*/
std::string CV3SVMOperate::predictCharName(cv::Mat& img)
{
    int value = this->predict(img);
    return CharsLabelValue::LabelValueToCharName(value);
}


/*
* All parameters must satisfy the following conditions
* (winSize.width - blockSize.width) % blockStride.width == 0
* (winSize.height - blockSize.height) % blockStride.height == 0
* (blockSize.width % cellSize.width) == 0
* (blockSize.height % cellSize.height) == 0
*/
void CV3SVMOperate::setHOG(cv::Size winSize, cv::Size blockSize, cv::Size blockStride, cv::Size cellSize, int nbins)
{
    // HOG descriptor
    m_hog = cv::HOGDescriptor(winSize, blockSize, blockStride, cellSize, nbins);
}


CV3SVMOperate::CV3SVMOperate():m_svm(cv::ml::SVM::create()), m_hog(cv::HOGDescriptor())
{
}


/*
* Constructor
* modelPath: path of the trained model file
*/
CV3SVMOperate::CV3SVMOperate(const std::string& modelPath)
{
    m_svm = cv::Algorithm::load<cv::ml::SVM>(modelPath);
}

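/*
* Set the SVM properties
*/
void CV3SVMOperate::setSVM()
{
    cv::TermCriteria tc = cv::TermCriteria(cv::TermCriteria::MAX_ITER + cv::TermCriteria::EPS, 1000/*maxCount*/, FLT_EPSILON);

    m_svm->setType(cv::ml::SVM::Types::C_SVC);              // SVM type: C-support vector classification
    m_svm->setKernel(cv::ml::SVM::KernelTypes::RBF);        // kernel type: radial basis function
    m_svm->setDegree(10.0);                                 // only used by the POLY kernel
    m_svm->setGamma(0.09);                                  // gamma of the RBF kernel (OpenCV's default is 1.0)
    m_svm->setCoef0(1.0);                                   // only used by the POLY and SIGMOID kernels
    m_svm->setC(10.0);                                      // penalty parameter C of C_SVC
    m_svm->setNu(0.5);                                      // only used by NU_SVC, ONE_CLASS and NU_SVR
    m_svm->setP(1.0);                                       // only used by EPS_SVR
    m_svm->setTermCriteria(tc);
}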


/*
* Train;
* does not save a model file
*/
void CV3SVMOperate::train()
{
    // Convert the data into the matrix layout that the SVM expects
    cv::Mat trainData;
    ConvertMLData(m_allGrads, trainData);
    cv::Mat labels = cv::Mat(m_allLabels);

    m_svm->train(trainData, cv::ml::ROW_SAMPLE, labels);
}


/*
* Train;
* saves the model file
* savePath: path of the model file written after SVM training
*/
void CV3SVMOperate::train(const std::string& savePath)
{
    // Train, then save the model file
    this->train();
    m_svm->save(savePath);
}


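/*
* Classify the input image with the model; the class is returned as an int
* img: input image
* return: the class the image belongs to, as a number only
*/
int CV3SVMOperate::predict(cv::Mat& img)
{
    // The prediction is only a label value; it is not mapped to a character here.
    std::vector<float> descriptors;
    cv::Size winStride(0, 0);
    m_hog.compute(img, descriptors, winStride);     // compute the HOG descriptor

    cv::Mat input;
    cv::Mat(descriptors).copyTo(input);
    cv::transpose(input, input);                    // Nx1 column -> 1xN row sample
    // Predict; SVM::predict returns the label as a float
    int ret = static_cast<int>(this->m_svm->predict(input));
    return ret;
}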


/*
* Add all images of one class, together with their label, to the training data
* imgs: all images, as Mat
* labelValue: label value of imgs; see the static variables of the CharsLabelValue class
*/
void CV3SVMOperate::addTrainSample(std::vector<cv::Mat>& imgs, int labelValue)
{
    std::vector<cv::Mat> curAllGrads;       // HOG descriptors of the current batch

    // imgs are the preprocessed images; now compute their HOG descriptors
    for (size_t i = 0; i < imgs.size(); i++)
    {
        cv::Mat m = imgs.at(i);
        std::vector<float> descriptors;
        cv::Size winStride(0, 0);
        m_hog.compute(m, descriptors, winStride);       // compute the HOG descriptor
        curAllGrads.push_back(cv::Mat(descriptors).clone());
    }

    // Append the current descriptors to all descriptors collected so far
    m_allGrads.insert(m_allGrads.end(), curAllGrads.begin(), curAllGrads.end());

    // The labels must be appended as well
    std::vector<int> curLabels(curAllGrads.size(), labelValue);
    m_allLabels.insert(m_allLabels.end(), curLabels.begin(), curLabels.end());

}
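setSVM selects the RBF kernel with gamma = 0.09. For reference, that kernel is K(x, y) = exp(-gamma * ||x - y||^2), and the sketch below evaluates it for two feature vectors. `rbfKernel` is a hypothetical helper written only for illustration; it is not part of the OpenCV API or of the project code, and OpenCV evaluates the kernel internally during training and prediction.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical helper: RBF kernel value for two equally long feature vectors.
// K(x, y) = exp(-gamma * squared Euclidean distance between x and y)
double rbfKernel(const std::vector<float>& x, const std::vector<float>& y, double gamma)
{
    assert(x.size() == y.size());   // both descriptors must have the same length

    double squaredDistance = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i)
    {
        const double diff = static_cast<double>(x[i]) - static_cast<double>(y[i]);
        squaredDistance += diff * diff;
    }
    return std::exp(-gamma * squaredDistance);
}
```

Identical descriptors give a kernel value of 1, and the value decays toward 0 as the descriptors move apart. A larger gamma makes that decay faster, so each training sample influences a smaller neighbourhood.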

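This code needs to read all the files under a folder. OpenCV 3.x no longer ships the Directory helper class that OpenCV 2.4.x provided (cv::glob offers similar functionality), so we pull that part of the OpenCV 2.4.x code out into its own header and source file, as shown below: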

// opencv2410_inputoutput.h


#ifndef OPENCV2410_INPUT_OUTPUT_H
#define OPENCV2410_INPUT_OUTPUT_H

#include <vector>
#include <string>

class  Directory
{
public:
    static std::vector<std::string> GetListFiles(const std::string& path, const std::string & exten = "*", bool addPath = true);
    static std::vector<std::string> GetListFilesR(const std::string& path, const std::string & exten = "*", bool addPath = true);
    static std::vector<std::string> GetListFolders(const std::string& path, const std::string & exten = "*", bool addPath = true);
};


#endif // !OPENCV2410_INPUT_OUTPUT_H



// opencv2410_inputoutput.cpp

#include "opencv2410_inputoutput.h"
//#include "opencv2/contrib/contrib.hpp"
//#include <cvconfig.h>

#include <cstring>      // strcmp

#if defined(WIN32) || defined(_WIN32)
    #include <windows.h>
    #include <tchar.h>
#else
    #include <dirent.h>
#endif

std::vector<std::string> Directory::GetListFiles(const std::string& path, const std::string & exten, bool addPath)
{
    std::vector<std::string> list;
    list.clear();
    std::string path_f = path + "/" + exten;
#if defined(WIN32) || defined(_WIN32)   // same test as the include block above
#ifdef HAVE_WINRT
    WIN32_FIND_DATAW FindFileData;
#else
    WIN32_FIND_DATAA FindFileData;
#endif
    HANDLE hFind;

#ifdef HAVE_WINRT
    wchar_t wpath[MAX_PATH];
    size_t copied = mbstowcs(wpath, path_f.c_str(), MAX_PATH);
    CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
    hFind = FindFirstFileExW(wpath, FindExInfoStandard, &FindFileData, FindExSearchNameMatch, NULL, 0);
#else
    hFind = FindFirstFileA((LPCSTR)path_f.c_str(), &FindFileData);
#endif
    if (hFind == INVALID_HANDLE_VALUE)
    {
        return list;
    }
    else
    {
        do
        {
            if (FindFileData.dwFileAttributes == FILE_ATTRIBUTE_NORMAL ||
                FindFileData.dwFileAttributes == FILE_ATTRIBUTE_ARCHIVE ||
                FindFileData.dwFileAttributes == FILE_ATTRIBUTE_HIDDEN ||
                FindFileData.dwFileAttributes == FILE_ATTRIBUTE_SYSTEM ||
                FindFileData.dwFileAttributes == FILE_ATTRIBUTE_READONLY)
            {
                char* fname;
#ifdef HAVE_WINRT
                char fname_tmp[MAX_PATH] = { 0 };
                size_t copied = wcstombs(fname_tmp, FindFileData.cFileName, MAX_PATH);
                CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
                fname = fname_tmp;
#else
                fname = FindFileData.cFileName;
#endif
                if (addPath)
                    list.push_back(path + "/" + std::string(fname));
                else
                    list.push_back(std::string(fname));
            }
        }
#ifdef HAVE_WINRT
        while (FindNextFileW(hFind, &FindFileData));
#else
        while (FindNextFileA(hFind, &FindFileData));
#endif
        FindClose(hFind);
    }
#else
    DIR *dp;
    struct dirent *dirp;
    if ((dp = opendir(path.c_str())) == NULL)
    {
        return list;
    }

    // readdir does not expand wildcards: treat "*.jpg" as a match on ".jpg"
    const std::string pattern = (!exten.empty() && exten[0] == '*') ? exten.substr(1) : exten;

    while ((dirp = readdir(dp)) != NULL)
    {
        if (dirp->d_type == DT_REG)
        {
            const std::string name(dirp->d_name);
            if (name.find(pattern) != std::string::npos)
                list.push_back(addPath ? path + "/" + name : name);
        }
    }
    closedir(dp);
#endif

    return list;
}

std::vector<std::string> Directory::GetListFolders(const std::string& path, const std::string & exten, bool addPath)
{
    std::vector<std::string> list;
    std::string path_f = path + "/" + exten;
    list.clear();
#if defined(WIN32) || defined(_WIN32)   // same test as the include block above
#ifdef HAVE_WINRT
    WIN32_FIND_DATAW FindFileData;
#else
    WIN32_FIND_DATAA FindFileData;
#endif
    HANDLE hFind;

#ifdef HAVE_WINRT
    wchar_t wpath[MAX_PATH];
    size_t copied = mbstowcs(wpath, path_f.c_str(), path_f.size());
    CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));

    hFind = FindFirstFileExW(wpath, FindExInfoStandard, &FindFileData, FindExSearchNameMatch, NULL, 0);
#else
    hFind = FindFirstFileA((LPCSTR)path_f.c_str(), &FindFileData);
#endif
    if (hFind == INVALID_HANDLE_VALUE)
    {
        return list;
    }
    else
    {
        do
        {
#ifdef HAVE_WINRT
            if (FindFileData.dwFileAttributes == FILE_ATTRIBUTE_DIRECTORY &&
                wcscmp(FindFileData.cFileName, L".") != 0 &&
                wcscmp(FindFileData.cFileName, L"..") != 0)
#else
            if (FindFileData.dwFileAttributes == FILE_ATTRIBUTE_DIRECTORY &&
                strcmp(FindFileData.cFileName, ".") != 0 &&
                strcmp(FindFileData.cFileName, "..") != 0)
#endif
            {
                char* fname;
#ifdef HAVE_WINRT
                char fname_tmp[MAX_PATH];
                size_t copied = wcstombs(fname_tmp, FindFileData.cFileName, MAX_PATH);
                CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
                fname = fname_tmp;
#else
                fname = FindFileData.cFileName;
#endif

                if (addPath)
                    list.push_back(path + "/" + std::string(fname));
                else
                    list.push_back(std::string(fname));
            }
        }
#ifdef HAVE_WINRT
        while (FindNextFileW(hFind, &FindFileData));
#else
        while (FindNextFileA(hFind, &FindFileData));
#endif
        FindClose(hFind);
    }

#else
    (void)addPath;
    DIR *dp;
    struct dirent *dirp;
    if ((dp = opendir(path_f.c_str())) == NULL)
    {
        return list;
    }

    while ((dirp = readdir(dp)) != NULL)
    {
        if (dirp->d_type == DT_DIR &&
            strcmp(dirp->d_name, ".") != 0 &&
            strcmp(dirp->d_name, "..") != 0)
        {
            if (exten.compare("*") == 0)
                list.push_back(static_cast<std::string>(dirp->d_name));
            else
                if (std::string(dirp->d_name).find(exten) != std::string::npos)
                    list.push_back(static_cast<std::string>(dirp->d_name));
        }
    }
    closedir(dp);
#endif

    return list;
}

std::vector<std::string> Directory::GetListFilesR(const std::string& path, const std::string & exten, bool addPath)
{
    std::vector<std::string> list = Directory::GetListFiles(path, exten, addPath);

    std::vector<std::string> dirs = Directory::GetListFolders(path, exten, addPath);

    std::vector<std::string>::const_iterator it;
    for (it = dirs.begin(); it != dirs.end(); ++it)
    {
        std::vector<std::string> cl = Directory::GetListFiles(*it, exten, addPath);
        list.insert(list.end(), cl.begin(), cl.end());
    }

    return list;
}
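If a C++17 compiler is available, the same directory listing can be done with the standard library instead of platform-specific code. The sketch below is an alternative under that assumption, not the code this post uses: `hasJpgExtension` and `listJpgFiles` are hypothetical names, the listing is non-recursive, and the extension match ignores case.

```cpp
#include <algorithm>
#include <cassert>
#include <cctype>
#include <filesystem>
#include <string>
#include <system_error>
#include <vector>

namespace fs = std::filesystem;

// Hypothetical helper: true when the path ends in ".jpg", ignoring case.
bool hasJpgExtension(const fs::path& p)
{
    std::string ext = p.extension().string();
    std::transform(ext.begin(), ext.end(), ext.begin(),
                   [](unsigned char c) { return static_cast<char>(std::tolower(c)); });
    return ext == ".jpg";
}

// Hypothetical helper: sorted paths of the regular .jpg files directly inside
// `dir`. Returns an empty list when the directory cannot be opened.
std::vector<std::string> listJpgFiles(const std::string& dir)
{
    assert(!dir.empty());   // an empty path is a caller error

    std::vector<std::string> files;
    std::error_code ec;
    for (fs::directory_iterator it(dir, ec), end; !ec && it != end; it.increment(ec))
    {
        std::error_code fileEc;
        if (it->is_regular_file(fileEc) && hasJpgExtension(it->path()))
        {
            files.push_back(it->path().string());
        }
    }
    std::sort(files.begin(), files.end());  // directory order is unspecified
    return files;
}
```

The error_code overloads are used so that a missing or unreadable directory yields an empty list rather than an exception, which matches how GetListFiles behaves when the directory cannot be opened.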

The comments in the code already explain it thoroughly, so I will not go into more detail in this post.
