SVM是一种二分类算法,当然,也可以用来做多分类。OpenCV的SVM已经提供了多分类,不需要我们手动去写代码。我们可以将SVM用在数字和字符的分类上面,比如我们已经获得一幅图片,要判断这个字符到底是数字0-9中的哪一个。根据SVM的分类,我们可以准确地知道这幅图片属于哪一类。本博客将提供完整的测试代码以及数据下载。
由于涉及到机器学习以及样本的个数不同,不同的硬件将耗费的时间也不同。下面列出我试验的电脑硬件配置
CPU i7-8750H
RAM 8GB
硬盘 128G SSD
显卡 GTX 1050Ti
我在训练过程中,CPU
使用率大约是20%,耗费时间大约是30s
。如果是i5
的计算机,耗费时间将会很长,请耐心等待。
数据集来自以下的csdn下载,感谢这位朋友提供的数据以便于个人的测试。
为了方便测试,我们只提供了数字0-9和字母G、T、W、X的分类,也就是说,随便拿一张图片,将其分到这14类中的某一类去。
由于这位朋友的数据在每一类中的图片个数不一样,所以,我每一类中抽取200张图片作为测试,剩下的图片全部作为样本。
我们先上个结果吧,分别显示了训练阶段的每一类的训练样本个数,并显示了测试集的准确率。从结果可知,准确率很难达到100%
。后期,我们将提供所有的工程代码下载地址。
// main.cpp
#include <iostream>
#include "opencv2/opencv.hpp"
/**
 * Program entry point: runs the character-recognition demo once and prints
 * the elapsed time in seconds, computed from OpenCV tick counts.
 */
int main()
{
    // The demo driver lives in test_cv3_char_recog.cpp; no header exposes it,
    // so it is declared locally.
    void test_CV3CharRecog();

    const double tickBegin = cv::getTickCount();
    test_CV3CharRecog();
    const double tickEnd = cv::getTickCount();

    // ticks / (ticks per second) = seconds
    const double elapsedSeconds = (tickEnd - tickBegin) / cv::getTickFrequency();
    std::cout << "time = " << elapsedSeconds << std::endl;

    // Runs the shell's "pause" command so the console stays open.
    system("pause");
    return 0;
}
上面是主函数,接下来提供的是主要函数调用
//test_cv3_char_recog.cpp
/**
* 训练阶段:
* [1] 批量导入图片.
* 本Demo自带一个批量导入图片的函数, 也可以采用Qt等接口
* 一个目录下仅一个类的图片,比如数字2,不能有其他类的图片.
* 生成的数据为vector<Mat>
*
* [2] 将每一类的图片和标签添加到训练数据中。
* void CV3SVMOperate::addTrainSample(std::vector<cv::Mat>& imgs, int labelValue);
* 其中imgs为[1]的生成数据,而labelValue为CharsLabelValue的静态成员变量。
* 比如CharsLabelValue::NUM_ZERO,严禁使用数字代替
*
* [3] 设置SVM的参数和HOG参数,如果构造函数中已经指明,则不需要设置
*
* [4] 训练.
* void CV3SVMOperate::train();
* void CV3SVMOperate::train(const std::string& savePath);
* 上一个不保存模型文件,仅仅用来测试使用.
* 下面一个保存模型文件.
*
*
*
* 预测阶段:
* [1] 通过构造函数指明模型路径
* CV3SVMOperate::CV3SVMOperate(const std::string& modelPath);
*
* [2] 设置HOG参数
* void setHOG(cv::Size winSize, cv::Size blockSize, cv::Size blockStride, cv::Size cellSize, int nbins);
*
* [3] 预测
* int CV3SVMOperate::predict(cv::Mat& img);
* std::string CV3SVMOperate::predictCharName(cv::Mat& img);
* 第一个预测函数仅仅是返回预测的标签值
* 第二个预测函数返回标签值所对应的字符
*
*/
#include <iostream>
#include "opencv2/opencv.hpp"
#include "chars_info.h"
#include "opencv_3_svm_operate.h"
#include "opencv2410_inputoutput.h"
//#include "../ReadFilePaths/opencv2410_inputoutput.h"
/*
* 将dir路径下面所有的jpg图片不加任何处理的全部导入到imgsMat中, 并返回导入图片的个数
* dir: 目录的路径
* imgsMat: 存储所有图片的Mat
* return: 当前目录的图片个数
*/
static int LoadImgs(const std::string& dir, std::vector<cv::Mat>& imgsMat)
{
imgsMat.clear();
//Directory d;
std::vector<std::string> filePaths = Directory::GetListFilesR(dir, "*.jpg", true);
for (std::string str : filePaths)
{
cv::Mat img = cv::imread(str, cv::IMREAD_COLOR);
cv::Mat tmp;
cv::resize(img, tmp, cv::Size(16, 25));
imgsMat.push_back(tmp);
}
//cv::waitKey(1);
return imgsMat.size();
}
/*
 * Loads all images of one class directory and registers them, together with
 * their label, as training samples; then prints the class name and the
 * number of samples that were loaded.
 *
 * cv3SVM:     trainer object that accumulates the training data
 * path:       directory holding the images of exactly one class
 * labelValue: label of that class; intended to be one of the CharsLabelValue
 *             constants, and each class must use a distinct value
 */
static void AddOneDir(CV3SVMOperate& cv3SVM, std::string path, int labelValue)
{
    std::vector<cv::Mat> classImages;
    LoadImgs(path, classImages);
    cv3SVM.addTrainSample(classImages, labelValue);

    const std::string charName = CharsLabelValue::LabelValueToCharName(labelValue);
    std::cout << "类别:" << charName
              << "\t样本个数:" << classImages.size() << std::endl;
}
/*
 * Predicts every image of one test class, counts how many predictions equal
 * the expected character name, and prints the sample count, hit count and
 * accuracy.
 *
 * imgs:     all test images of one class
 * cv3SVM:   classifier used for prediction
 * charName: the correct character name for every image in imgs
 *
 * An empty imgs reports an accuracy of 0 instead of dividing 0 by 0.
 */
static void CompareResult(std::vector<cv::Mat>& imgs, CV3SVMOperate& cv3SVM, std::string charName)
{
    int cnt = 0;
    // Range-for avoids the signed/unsigned index comparison of the old loop.
    for (cv::Mat& img : imgs)
    {
        if (cv3SVM.predictCharName(img) == charName)
        {
            ++cnt;
        }
    }

    // Guard the division: with no samples the ratio used to be NaN.
    const double ratio = imgs.empty()
        ? 0.0
        : static_cast<double>(cnt) / static_cast<double>(imgs.size());

    std::cout << "字符名称:";
    std::cout << charName << " ";
    std::cout << "测试样本数目为:" << imgs.size();
    std::cout << "\t成功个数:" << cnt;
    std::cout << "\t准确率为:" << ratio << std::endl;
}
/*
 * Evaluates the classifier on one test directory: loads its images and
 * compares every prediction with the expected character name.
 *
 * imgDir: directory containing test images of a single class
 * cv3SVM: classifier under test
 * result: character name that every image in imgDir should be predicted as
 */
static void test_OneDr(std::string& imgDir, CV3SVMOperate& cv3SVM, std::string result)
{
    std::vector<cv::Mat> testImages;
    LoadImgs(imgDir, testImages);
    CompareResult(testImages, cv3SVM, result);
}
/*
 * Training stage of the demo: configures HOG and SVM, feeds the 14 class
 * directories (digits 0-9 and letters G, T, W, X) found under the trainData
 * folder into the trainer, trains the model and saves it to modelPath.
 */
static void test_trainHOG()
{
    const std::string prefix = "../Resources/CharRecDemoImage/trainData/";
    const std::string modelPath = "../Resources/b.xml";

    // One row per class: sub-directory name under prefix and its label.
    struct ClassDir
    {
        const char* dirName;
        int labelValue;
    };
    const ClassDir classDirs[] = {
        { "0", CharsLabelValue::NUM_ZERO },
        { "1", CharsLabelValue::NUM_ONE },
        { "2", CharsLabelValue::NUM_TWO },
        { "3", CharsLabelValue::NUM_THREE },
        { "4", CharsLabelValue::NUM_FOUR },
        { "5", CharsLabelValue::NUM_FIVE },
        { "6", CharsLabelValue::NUM_SIX },
        { "7", CharsLabelValue::NUM_SEVEN },
        { "8", CharsLabelValue::NUM_EIGHT },
        { "9", CharsLabelValue::NUM_NINE },
        { "G", CharsLabelValue::CHAR_CAPTITAL_G },
        { "T", CharsLabelValue::CHAR_CAPTITAL_T },
        { "W", CharsLabelValue::CHAR_CAPTITAL_W },
        { "X", CharsLabelValue::CHAR_CAPTITAL_X },
    };

    CV3SVMOperate cv3SVM;
    cv3SVM.setHOG(cv::Size(16, 24)/*winSize*/, cv::Size(8, 8)/*blockSize*/, cv::Size(2, 4)/*blockStride*/, cv::Size(4, 4)/*cellSize*/, 9/*nbins*/);
    cv3SVM.setSVM();

    // Register the training samples, class by class, in table order.
    std::cout << "添加训练数据开始!" << std::endl;
    for (const ClassDir& entry : classDirs)
    {
        AddOneDir(cv3SVM, prefix + entry.dirName, entry.labelValue);
    }
    std::cout << "添加训练数据结束!" << std::endl;

    std::cout << "训练开始!" << std::endl;
    cv3SVM.train(modelPath);
    std::cout << "训练完成!" << std::endl;

    cv::waitKey(1);
}
/*
 * Prediction stage of the demo: loads the saved model, applies the same HOG
 * parameters used for training, and reports the accuracy for each of the 14
 * class directories under the testData folder.
 */
static void test_Predict()
{
    const std::string modelPath = "../Resources/b.xml";
    const std::string prefix = "../Resources/CharRecDemoImage/testData/";

    CV3SVMOperate cv3SVM(modelPath);
    cv3SVM.setHOG(cv::Size(16, 24)/*winSize*/, cv::Size(8, 8)/*blockSize*/, cv::Size(2, 4)/*blockStride*/, cv::Size(4, 4)/*cellSize*/, 9/*nbins*/);

    // Each entry is both the sub-directory name and the expected character.
    const char* const classNames[] = {
        "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
        "G", "T", "W", "X"
    };

    for (const char* className : classNames)
    {
        // test_OneDr takes the directory by non-const reference, so a named
        // string is required here.
        std::string imgDir = prefix + className;
        test_OneDr(imgDir, cv3SVM, className);
    }
}
/*
 * Demo driver: runs the training stage, then the prediction stage, printing
 * a banner before and after each of them.
 */
void test_CV3CharRecog()
{
    // Block-scope declarations of the two stages implemented in this file.
    void test_trainHOG();
    void test_Predict();

    std::cout << "训练阶段!" << std::endl;
    test_trainHOG();

    std::cout << "测试训练模型!" << std::endl
              << "---------------------------" << std::endl
              << "测试开始!" << std::endl;
    test_Predict();
    std::cout << "测试结束!" << std::endl;

    cv::waitKey(1);
}
接下来分别提供的是主要头文件:
// chars_info.h
#ifndef CHARS_INFO_H
#define CHARS_INFO_H
#include <string>
/**
 * Named integer labels for the characters the classifier can distinguish
 * (digits, uppercase letters, lowercase letters), plus the mapping from a
 * label back to its printable character. Used only through static members.
 */
class CharsLabelValue
{
public:
    /**
     * Maps a label value to the character name it stands for.
     * v: label value
     * return: the character as a string
     */
    static std::string LabelValueToCharName(int v);

    // Digit labels.
    static const int NUM_ZERO, NUM_ONE, NUM_TWO, NUM_THREE, NUM_FOUR,
                     NUM_FIVE, NUM_SIX, NUM_SEVEN, NUM_EIGHT, NUM_NINE;

    // Uppercase-letter labels (the "CAPTITAL" spelling is part of the
    // public names and is kept for compatibility).
    static const int CHAR_CAPTITAL_A, CHAR_CAPTITAL_B, CHAR_CAPTITAL_C, CHAR_CAPTITAL_D,
                     CHAR_CAPTITAL_E, CHAR_CAPTITAL_F, CHAR_CAPTITAL_G, CHAR_CAPTITAL_H,
                     CHAR_CAPTITAL_I, CHAR_CAPTITAL_J, CHAR_CAPTITAL_K, CHAR_CAPTITAL_L,
                     CHAR_CAPTITAL_M, CHAR_CAPTITAL_N, CHAR_CAPTITAL_O, CHAR_CAPTITAL_P,
                     CHAR_CAPTITAL_Q, CHAR_CAPTITAL_R, CHAR_CAPTITAL_S, CHAR_CAPTITAL_T,
                     CHAR_CAPTITAL_U, CHAR_CAPTITAL_V, CHAR_CAPTITAL_W, CHAR_CAPTITAL_X,
                     CHAR_CAPTITAL_Y, CHAR_CAPTITAL_Z;

    // Lowercase-letter labels.
    static const int CHAR_LOWERCASE_A, CHAR_LOWERCASE_B, CHAR_LOWERCASE_C, CHAR_LOWERCASE_D,
                     CHAR_LOWERCASE_E, CHAR_LOWERCASE_F, CHAR_LOWERCASE_G, CHAR_LOWERCASE_H,
                     CHAR_LOWERCASE_I, CHAR_LOWERCASE_J, CHAR_LOWERCASE_K, CHAR_LOWERCASE_L,
                     CHAR_LOWERCASE_M, CHAR_LOWERCASE_N, CHAR_LOWERCASE_O, CHAR_LOWERCASE_P,
                     CHAR_LOWERCASE_Q, CHAR_LOWERCASE_R, CHAR_LOWERCASE_S, CHAR_LOWERCASE_T,
                     CHAR_LOWERCASE_U, CHAR_LOWERCASE_V, CHAR_LOWERCASE_W, CHAR_LOWERCASE_X,
                     CHAR_LOWERCASE_Y, CHAR_LOWERCASE_Z;

private:
    // Private constructor: the class is not meant to be instantiated.
    CharsLabelValue();
};
#endif // !CHARS_INFO_H
// opencv_3_svm_operate.h
#ifndef OPENCV_3_SVM_OPERATE_H
#define OPENCV_3_SVM_OPERATE_H
#include <string>
#include <vector>
//#include "opencv2/opencv.hpp"
#include "opencv2/core/core.hpp"
#include "opencv2/ml/ml.hpp"
#include "opencv2/objdetect/objdetect.hpp"
/**
 * Wrapper around an OpenCV SVM that is fed with HOG descriptors.
 * It accumulates labelled training images, trains (and optionally saves)
 * the model, and predicts the label or character name of an image.
 */
class CV3SVMOperate
{
public:
    /** Creates an object intended for training. */
    CV3SVMOperate();

    /**
     * Creates an object from a previously trained model.
     * modelPath: path of the model file written by train(savePath)
     */
    CV3SVMOperate(const std::string& modelPath);

    /**
     * Configures the HOG descriptor. The arguments must satisfy:
     *   (winSize.width  - blockSize.width)  % blockStride.width  == 0
     *   (winSize.height - blockSize.height) % blockStride.height == 0
     *   (blockSize.width  % cellSize.width)  == 0
     *   (blockSize.height % cellSize.height) == 0
     */
    void setHOG(cv::Size winSize, cv::Size blockSize, cv::Size blockStride, cv::Size cellSize, int nbins);

    /** Applies the SVM parameters used by this demo. */
    void setSVM();

    /** Trains on the accumulated samples without saving a model file. */
    void train();

    /**
     * Trains on the accumulated samples and saves the model.
     * savePath: path of the model file to write
     */
    void train(const std::string& savePath);

    /**
     * Classifies one image.
     * img: input image
     * return: predicted class as an integer label value
     */
    int predict(cv::Mat& img);

    /**
     * Classifies one image and returns the character it stands for.
     * img: input image
     * return: the predicted character as a string
     */
    std::string predictCharName(cv::Mat& img);

    /**
     * Adds all images of one class, with their label, to the training data.
     * imgs:       images of that class
     * labelValue: label shared by all of imgs; see the static constants of
     *             CharsLabelValue
     */
    void addTrainSample(std::vector<cv::Mat>& imgs, int labelValue);

private:
    cv::Ptr<cv::ml::SVM> m_svm;      // the SVM model
    cv::HOGDescriptor m_hog;         // HOG feature extractor
    std::vector<int> m_allLabels;    // label of each training sample
    std::vector<cv::Mat> m_allGrads; // HOG descriptor of each training sample
};
#endif // !OPENCV_3_SVM_OPERATE_H
针对上述的头文件具体实现:
// chars_info.cpp
#include "chars_info.h"
// Label values for the digits 0-9
const int CharsLabelValue::NUM_ZERO = 0;
const int CharsLabelValue::NUM_ONE = 1;
const int CharsLabelValue::NUM_TWO = 2;
const int CharsLabelValue::NUM_THREE = 3;
const int CharsLabelValue::NUM_FOUR = 4;
const int CharsLabelValue::NUM_FIVE = 5;
const int CharsLabelValue::NUM_SIX = 6;
const int CharsLabelValue::NUM_SEVEN = 7;
const int CharsLabelValue::NUM_EIGHT = 8;
const int CharsLabelValue::NUM_NINE = 9;
// Label values for the uppercase letters
const int CharsLabelValue::CHAR_CAPTITAL_A = 10;
const int CharsLabelValue::CHAR_CAPTITAL_B = 11;
const int CharsLabelValue::CHAR_CAPTITAL_C = 12;
const int CharsLabelValue::CHAR_CAPTITAL_D = 13;
const int CharsLabelValue::CHAR_CAPTITAL_E = 14;
const int CharsLabelValue::CHAR_CAPTITAL_F = 15;
const int CharsLabelValue::CHAR_CAPTITAL_G = 16;
const int CharsLabelValue::CHAR_CAPTITAL_H = 17;
const int CharsLabelValue::CHAR_CAPTITAL_I = 18;
const int CharsLabelValue::CHAR_CAPTITAL_J = 19;
const int CharsLabelValue::CHAR_CAPTITAL_K = 20;
const int CharsLabelValue::CHAR_CAPTITAL_L = 21;
const int CharsLabelValue::CHAR_CAPTITAL_M = 22;
const int CharsLabelValue::CHAR_CAPTITAL_N = 23;
const int CharsLabelValue::CHAR_CAPTITAL_O = 24;
const int CharsLabelValue::CHAR_CAPTITAL_P = 25;
const int CharsLabelValue::CHAR_CAPTITAL_Q = 26;
const int CharsLabelValue::CHAR_CAPTITAL_R = 27;
const int CharsLabelValue::CHAR_CAPTITAL_S = 28;
const int CharsLabelValue::CHAR_CAPTITAL_T = 29;
const int CharsLabelValue::CHAR_CAPTITAL_U = 30;
const int CharsLabelValue::CHAR_CAPTITAL_V = 31;
const int CharsLabelValue::CHAR_CAPTITAL_W = 32;
const int CharsLabelValue::CHAR_CAPTITAL_X = 33;
const int CharsLabelValue::CHAR_CAPTITAL_Y = 34;
const int CharsLabelValue::CHAR_CAPTITAL_Z = 35;
// Label values for the lowercase letters
const int CharsLabelValue::CHAR_LOWERCASE_A = 36;
const int CharsLabelValue::CHAR_LOWERCASE_B = 37;
const int CharsLabelValue::CHAR_LOWERCASE_C = 38;
const int CharsLabelValue::CHAR_LOWERCASE_D = 39;
const int CharsLabelValue::CHAR_LOWERCASE_E = 40;
const int CharsLabelValue::CHAR_LOWERCASE_F = 41;
const int CharsLabelValue::CHAR_LOWERCASE_G = 42;
const int CharsLabelValue::CHAR_LOWERCASE_H = 43;
const int CharsLabelValue::CHAR_LOWERCASE_I = 44;
const int CharsLabelValue::CHAR_LOWERCASE_J = 45;
const int CharsLabelValue::CHAR_LOWERCASE_K = 46;
const int CharsLabelValue::CHAR_LOWERCASE_L = 47;
const int CharsLabelValue::CHAR_LOWERCASE_M = 48;
const int CharsLabelValue::CHAR_LOWERCASE_N = 49;
const int CharsLabelValue::CHAR_LOWERCASE_O = 50;
const int CharsLabelValue::CHAR_LOWERCASE_P = 51;
const int CharsLabelValue::CHAR_LOWERCASE_Q = 52;
const int CharsLabelValue::CHAR_LOWERCASE_R = 53;
const int CharsLabelValue::CHAR_LOWERCASE_S = 54;
const int CharsLabelValue::CHAR_LOWERCASE_T = 55;
const int CharsLabelValue::CHAR_LOWERCASE_U = 56;
const int CharsLabelValue::CHAR_LOWERCASE_V = 57;
const int CharsLabelValue::CHAR_LOWERCASE_W = 58;
const int CharsLabelValue::CHAR_LOWERCASE_X = 59;
const int CharsLabelValue::CHAR_LOWERCASE_Y = 60;
const int CharsLabelValue::CHAR_LOWERCASE_Z = 61;
/*
* 返回预测值所对应的字符名称
* v: 值
* return: 字符名称
*/
std::string CharsLabelValue::LabelValueToCharName(int v)
{
/*
* 对于不需要判断的字符,可以注释掉判断代码
* 或者是对需要分类的字符调整一下顺序,避免速度问题
*/
// 判断是哪个数字
if (v == CharsLabelValue::NUM_ZERO) return "0";
if (v == CharsLabelValue::NUM_ONE) return "1";
if (v == CharsLabelValue::NUM_TWO) return "2";
if (v == CharsLabelValue::NUM_THREE) return "3";
if (v == CharsLabelValue::NUM_FOUR) return "4";
if (v == CharsLabelValue::NUM_FIVE) return "5";
if (v == CharsLabelValue::NUM_SIX) return "6";
if (v == CharsLabelValue::NUM_SEVEN) return "7";
if (v == CharsLabelValue::NUM_EIGHT) return "8";
if (v == CharsLabelValue::NUM_NINE) return "9";
// 判断是哪个大写字符
if (v == CharsLabelValue::CHAR_CAPTITAL_A) return "A";
if (v == CharsLabelValue::CHAR_CAPTITAL_B) return "B";
if (v == CharsLabelValue::CHAR_CAPTITAL_C) return "C";
if (v == CharsLabelValue::CHAR_CAPTITAL_D) return "D";
if (v == CharsLabelValue::CHAR_CAPTITAL_E) return "E";
if (v == CharsLabelValue::CHAR_CAPTITAL_F) return "F";
if (v == CharsLabelValue::CHAR_CAPTITAL_G) return "G";
if (v == CharsLabelValue::CHAR_CAPTITAL_H) return "H";
if (v == CharsLabelValue::CHAR_CAPTITAL_I) return "I";
if (v == CharsLabelValue::CHAR_CAPTITAL_J) return "J";
if (v == CharsLabelValue::CHAR_CAPTITAL_K) return "K";
if (v == CharsLabelValue::CHAR_CAPTITAL_L) return "L";
if (v == CharsLabelValue::CHAR_CAPTITAL_M) return "M";
if (v == CharsLabelValue::CHAR_CAPTITAL_N) return "N";
if (v == CharsLabelValue::CHAR_CAPTITAL_O) return "O";
if (v == CharsLabelValue::CHAR_CAPTITAL_P) return "P";
if (v == CharsLabelValue::CHAR_CAPTITAL_Q) return "Q";
if (v == CharsLabelValue::CHAR_CAPTITAL_R) return "R";
if (v == CharsLabelValue::CHAR_CAPTITAL_S) return "S";
if (v == CharsLabelValue::CHAR_CAPTITAL_T) return "T";
if (v == CharsLabelValue::CHAR_CAPTITAL_U) return "U";
if (v == CharsLabelValue::CHAR_CAPTITAL_V) return "V";
if (v == CharsLabelValue::CHAR_CAPTITAL_W) return "W";
if (v == CharsLabelValue::CHAR_CAPTITAL_X) return "X";
if (v == CharsLabelValue::CHAR_CAPTITAL_Y) return "Y";
if (v == CharsLabelValue::CHAR_CAPTITAL_Z) return "Z";
// 判断小写字符
if (v == CharsLabelValue::CHAR_LOWERCASE_A) return "a";
if (v == CharsLabelValue::CHAR_LOWERCASE_B) return "b";
if (v == CharsLabelValue::CHAR_LOWERCASE_C) return "c";
if (v == CharsLabelValue::CHAR_LOWERCASE_D) return "d";
if (v == CharsLabelValue::CHAR_LOWERCASE_E) return "e";
if (v == CharsLabelValue::CHAR_LOWERCASE_F) return "f";
if (v == CharsLabelValue::CHAR_LOWERCASE_G) return "g";
if (v == CharsLabelValue::CHAR_LOWERCASE_H) return "h";
if (v == CharsLabelValue::CHAR_LOWERCASE_I) return "i";
if (v == CharsLabelValue::CHAR_LOWERCASE_J) return "j";
if (v == CharsLabelValue::CHAR_LOWERCASE_K) return "k";
if (v == CharsLabelValue::CHAR_LOWERCASE_L) return "l";
if (v == CharsLabelValue::CHAR_LOWERCASE_M) return "m";
if (v == CharsLabelValue::CHAR_LOWERCASE_N) return "n";
if (v == CharsLabelValue::CHAR_LOWERCASE_O) return "o";
if (v == CharsLabelValue::CHAR_LOWERCASE_P) return "p";
if (v == CharsLabelValue::CHAR_LOWERCASE_Q) return "q";
if (v == CharsLabelValue::CHAR_LOWERCASE_R) return "r";
if (v == CharsLabelValue::CHAR_LOWERCASE_S) return "s";
if (v == CharsLabelValue::CHAR_LOWERCASE_T) return "t";
if (v == CharsLabelValue::CHAR_LOWERCASE_U) return "u";
if (v == CharsLabelValue::CHAR_LOWERCASE_V) return "v";
if (v == CharsLabelValue::CHAR_LOWERCASE_W) return "w";
if (v == CharsLabelValue::CHAR_LOWERCASE_X) return "x";
if (v == CharsLabelValue::CHAR_LOWERCASE_Y) return "y";
if (v == CharsLabelValue::CHAR_LOWERCASE_Z) return "z";
// 不属于任何以上条件的,均返回空
return "";
}
#include <string>
#include <vector>
#include "opencv_3_svm_operate.h"
#include "chars_info.h"
/*
 * Packs a list of per-sample descriptor vectors into one CV_32FC1 matrix
 * with one sample per row, the layout passed to the SVM as ROW_SAMPLE.
 *
 * gradientLst: one descriptor per sample, stored either as an Nx1 column
 *              or as a 1xN row
 * trainData:   output matrix of gradientLst.size() rows; left empty when
 *              gradientLst is empty
 *
 * The row length is taken from the first descriptor.
 * NOTE(review): descriptors are assumed to be CV_32F and all of the same
 * length; a descriptor of a different length would not land in its row --
 * confirm callers guarantee this.
 */
static void ConvertMLData(std::vector<cv::Mat>& gradientLst, cv::Mat& trainData)
{
    // gradientLst[0] below is undefined behaviour on an empty list.
    if (gradientLst.empty())
    {
        trainData = cv::Mat();
        return;
    }

    const int ROWS = (int)gradientLst.size();
    const int COLS = std::max(gradientLst[0].cols, gradientLst[0].rows);

    // Zero-filled so that a row skipped below never holds uninitialised data.
    trainData = cv::Mat::zeros(ROWS, COLS, CV_32FC1);

    for (int i = 0; i < ROWS; i++)
    {
        const cv::Mat& m = gradientLst[i];
        if (m.cols == 1)
        {
            // Column vector: transpose it into a row first.
            cv::Mat asRow;
            cv::transpose(m, asRow);
            asRow.copyTo(trainData.row(i));
        }
        else if (m.rows == 1)
        {
            m.copyTo(trainData.row(i));
        }
    }
}
/*
 * Predicts the class of an image and returns the matching character name.
 * img: input image
 * return: the character name for the predicted label, as produced by
 *         CharsLabelValue::LabelValueToCharName
 */
std::string CV3SVMOperate::predictCharName(cv::Mat& img)
{
    return CharsLabelValue::LabelValueToCharName(predict(img));
}
/*
 * Replaces the HOG descriptor with one built from the given parameters.
 * The arguments must satisfy:
 *   (winSize.width  - blockSize.width)  % blockStride.width  == 0
 *   (winSize.height - blockSize.height) % blockStride.height == 0
 *   (blockSize.width  % cellSize.width)  == 0
 *   (blockSize.height % cellSize.height) == 0
 */
void CV3SVMOperate::setHOG(cv::Size winSize, cv::Size blockSize, cv::Size blockStride, cv::Size cellSize, int nbins)
{
    const cv::HOGDescriptor configured(winSize, blockSize, blockStride, cellSize, nbins);
    m_hog = configured;
}
/*
 * Default constructor: creates a fresh SVM instance and a
 * default-constructed HOG descriptor.
 */
CV3SVMOperate::CV3SVMOperate()
    : m_svm(cv::ml::SVM::create())
    , m_hog()
{
}
/*
 * Constructor for prediction: loads a trained SVM from a model file.
 * modelPath: path of the model file produced by training
 *
 * NOTE(review): the result of the load is not checked here -- confirm how
 * a missing or invalid file should be handled.
 */
CV3SVMOperate::CV3SVMOperate(const std::string& modelPath)
    : m_svm(cv::Algorithm::load<cv::ml::SVM>(modelPath))
    , m_hog()
{
}
/*
* 设置SVM的属性
*/
void CV3SVMOperate::setSVM()
{
cv::TermCriteria tc = cv::TermCriteria(cv::TermCriteria::MAX_ITER + cv::TermCriteria::EPS, 1000/*maxCount*/, FLT_EPSILON);
m_svm->setType(cv::ml::SVM::Types::C_SVC); // 设置SVM类型
m_svm->setKernel(cv::ml::SVM::KernelTypes::RBF); //设置核类型
m_svm->setDegree(10.0);
m_svm->setGamma(0.09); // 设置gamma值,默认为1.0
m_svm->setCoef0(1.0);
m_svm->setC(10.0);
m_svm->setNu(0.5);
m_svm->setP(1.0);
m_svm->setTermCriteria(tc);
}
/*
* 训练;
* 不保存模型文件
*/
void CV3SVMOperate::train()
{
// 将数据转换成标准的svm类型
cv::Mat trainData;
ConvertMLData(m_allGrads, trainData);
cv::Mat labels = cv::Mat(m_allLabels);
m_svm->train(trainData, cv::ml::ROW_SAMPLE, labels);
}
/*
* 训练;
* 保存模型文件
* savePath: svm训练后的模型文件路径
*/
void CV3SVMOperate::train(const std::string& savePath)
{
// 训练之后并保存模型文件
this->train();
m_svm->save(savePath);
}
/*
 * Classifies one image with the loaded or trained model.
 * img: input image
 * return: the predicted label as an integer; it is a raw label value,
 *         not a character name
 */
int CV3SVMOperate::predict(cv::Mat& img)
{
    // HOG descriptor of the image.
    std::vector<float> hogFeatures;
    m_hog.compute(img, hogFeatures, cv::Size(0, 0));

    // Mat(vector) is a column; the SVM expects the sample as a row.
    cv::Mat sample;
    cv::Mat(hogFeatures).copyTo(sample);
    cv::transpose(sample, sample);

    // The SVM returns the label as float; convert it to int.
    return static_cast<int>(m_svm->predict(sample));
}
/*
 * Adds all images of one class, together with their label, to the
 * training data.
 * imgs:       images of that class
 * labelValue: label shared by all of imgs; see the static constants of
 *             CharsLabelValue
 */
void CV3SVMOperate::addTrainSample(std::vector<cv::Mat>& imgs, int labelValue)
{
    // Descriptors are collected in a temporary first and appended to the
    // members only after every image has been processed.
    std::vector<cv::Mat> batchGrads;
    for (const cv::Mat& image : imgs)
    {
        std::vector<float> hogFeatures;
        m_hog.compute(image, hogFeatures, cv::Size(0, 0));
        // clone(): the Mat must own its data once hogFeatures goes away.
        batchGrads.push_back(cv::Mat(hogFeatures).clone());
    }

    m_allGrads.insert(m_allGrads.end(), batchGrads.begin(), batchGrads.end());

    // One label per descriptor that was just appended.
    m_allLabels.insert(m_allLabels.end(), batchGrads.size(), labelValue);
}
本代码需要读取某个文件夹下的所有文件,OpenCV3.x版本不提供此功能,但是OpenCV2.4.x提供此功能。所以,我们将OpenCV2.4.x的这部分代码单独拉出来并建立一个头文件和源码文件,具体代码如下:
// opencv2410_inputoutput.h
#ifndef OPENCV2410_INPUT_OUTPUT_H
#define OPENCV2410_INPUT_OUTPUT_H
#include <vector>
#include <string>
/**
 * Directory listing helpers extracted from OpenCV 2.4.x.
 * Common parameters:
 *   path:    directory to inspect
 *   exten:   name pattern to match, "*" by default
 *   addPath: request entries as path + "/" + name instead of bare names
 */
class Directory
{
public:
    /** Lists the files located directly in path. */
    static std::vector<std::string> GetListFiles(const std::string& path,
                                                 const std::string& exten = "*",
                                                 bool addPath = true);

    /** Lists the files in path and in the sub-folders of path. */
    static std::vector<std::string> GetListFilesR(const std::string& path,
                                                  const std::string& exten = "*",
                                                  bool addPath = true);

    /** Lists the sub-folders located directly in path. */
    static std::vector<std::string> GetListFolders(const std::string& path,
                                                   const std::string& exten = "*",
                                                   bool addPath = true);
};
#endif // !OPENCV2410_INPUT_OUTPUT_H
// opencv2410_inputoutput.cpp
#include "opencv2410_inputoutput.h"
//#include "opencv2/contrib/contrib.hpp"
//#include <cvconfig.h>
#if defined(WIN32) || defined(_WIN32)
#include <windows.h>
#include <tchar.h>
#else
#include <dirent.h>
#endif
/*
 * Lists the regular files located directly in a directory.
 *
 * path:    directory to scan
 * exten:   name pattern. On Windows it is passed to FindFirstFile as
 *          path + "/" + exten. Elsewhere "*" and "*.*" match everything,
 *          "*<suffix>" (e.g. "*.jpg") matches names ending in <suffix>,
 *          and any other value is matched as a substring of the name.
 * addPath: true  -> entries are returned as path + "/" + name
 *          false -> entries are bare file names
 * return:  matching entries; empty when the directory cannot be opened
 *
 * Changes over the original OpenCV 2.4 code:
 *  - the platform test matches the one guarding the #include lines
 *    (WIN32 or _WIN32), so the branch and the headers always agree;
 *  - on Windows a file is recognised by the DIRECTORY bit being clear;
 *    comparing the whole attribute mask for equality skipped files that
 *    carry more than one attribute (e.g. ARCHIVE | READONLY);
 *  - the non-Windows branch now honours addPath and understands "*.ext"
 *    patterns; it used to return bare names and never matched "*.jpg".
 */
std::vector<std::string> Directory::GetListFiles(const std::string& path, const std::string & exten, bool addPath)
{
    std::vector<std::string> list;

#if defined(WIN32) || defined(_WIN32)
    const std::string path_f = path + "/" + exten;
#ifdef HAVE_WINRT
    WIN32_FIND_DATAW FindFileData;
    wchar_t wpath[MAX_PATH];
    size_t copied = mbstowcs(wpath, path_f.c_str(), MAX_PATH);
    CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
    HANDLE hFind = FindFirstFileExW(wpath, FindExInfoStandard, &FindFileData, FindExSearchNameMatch, NULL, 0);
#else
    WIN32_FIND_DATAA FindFileData;
    HANDLE hFind = FindFirstFileA((LPCSTR)path_f.c_str(), &FindFileData);
#endif
    if (hFind == INVALID_HANDLE_VALUE)
    {
        return list;
    }

    do
    {
        // dwFileAttributes is a bit mask: anything that is not a directory
        // is treated as a file.
        if ((FindFileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) == 0)
        {
            char* fname;
#ifdef HAVE_WINRT
            char fname_tmp[MAX_PATH] = { 0 };
            size_t copiedName = wcstombs(fname_tmp, FindFileData.cFileName, MAX_PATH);
            CV_Assert((copiedName != MAX_PATH) && (copiedName != (size_t)-1));
            fname = fname_tmp;
#else
            fname = FindFileData.cFileName;
#endif
            if (addPath)
                list.push_back(path + "/" + std::string(fname));
            else
                list.push_back(std::string(fname));
        }
    }
#ifdef HAVE_WINRT
    while (FindNextFileW(hFind, &FindFileData));
#else
    while (FindNextFileA(hFind, &FindFileData));
#endif
    FindClose(hFind);
#else
    // Name filter for the dirent-based branch (see the exten description).
    const auto matchesExten = [&exten](const std::string& name) -> bool
    {
        if (exten == "*" || exten == "*.*")
            return true;
        if (!exten.empty() && exten[0] == '*')
        {
            const std::string suffix = exten.substr(1);
            return name.size() >= suffix.size() &&
                   name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0;
        }
        return name.find(exten) != std::string::npos;
    };

    DIR* dp = opendir(path.c_str());
    if (dp == NULL)
    {
        return list;
    }

    struct dirent* dirp;
    while ((dirp = readdir(dp)) != NULL)
    {
        if (dirp->d_type != DT_REG)
            continue;

        const std::string name(dirp->d_name);
        if (!matchesExten(name))
            continue;

        list.push_back(addPath ? path + "/" + name : name);
    }
    closedir(dp);
#endif

    return list;
}
/*
 * Lists the sub-folders located directly in a directory ("." and ".." are
 * never reported).
 *
 * path:    directory to scan
 * exten:   name pattern. On Windows it is passed to FindFirstFile as
 *          path + "/" + exten. Elsewhere "*" and "*.*" match everything,
 *          "*<suffix>" matches names ending in <suffix>, and any other
 *          value is matched as a substring of the name.
 * addPath: true  -> entries are returned as path + "/" + name
 *          false -> entries are bare folder names
 * return:  matching entries; empty when the directory cannot be opened
 *
 * Changes over the original OpenCV 2.4 code:
 *  - the platform test matches the one guarding the #include lines
 *    (WIN32 or _WIN32);
 *  - on Windows a folder is recognised by the DIRECTORY bit being set;
 *    comparing the whole attribute mask for equality skipped folders that
 *    carry any additional attribute;
 *  - the WinRT conversion is bounded by MAX_PATH (the size of wpath), not
 *    by path_f.size(), which left wpath without a terminating zero;
 *  - the non-Windows branch opens path itself (it used to open
 *    path + "/" + exten), honours addPath and understands "*<suffix>"
 *    patterns.
 */
std::vector<std::string> Directory::GetListFolders(const std::string& path, const std::string & exten, bool addPath)
{
    std::vector<std::string> list;

#if defined(WIN32) || defined(_WIN32)
    const std::string path_f = path + "/" + exten;
#ifdef HAVE_WINRT
    WIN32_FIND_DATAW FindFileData;
    wchar_t wpath[MAX_PATH];
    size_t copied = mbstowcs(wpath, path_f.c_str(), MAX_PATH);
    CV_Assert((copied != MAX_PATH) && (copied != (size_t)-1));
    HANDLE hFind = FindFirstFileExW(wpath, FindExInfoStandard, &FindFileData, FindExSearchNameMatch, NULL, 0);
#else
    WIN32_FIND_DATAA FindFileData;
    HANDLE hFind = FindFirstFileA((LPCSTR)path_f.c_str(), &FindFileData);
#endif
    if (hFind == INVALID_HANDLE_VALUE)
    {
        return list;
    }

    do
    {
        // dwFileAttributes is a bit mask, so test the DIRECTORY bit.
#ifdef HAVE_WINRT
        if ((FindFileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) != 0 &&
            wcscmp(FindFileData.cFileName, L".") != 0 &&
            wcscmp(FindFileData.cFileName, L"..") != 0)
#else
        if ((FindFileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY) != 0 &&
            strcmp(FindFileData.cFileName, ".") != 0 &&
            strcmp(FindFileData.cFileName, "..") != 0)
#endif
        {
            char* fname;
#ifdef HAVE_WINRT
            char fname_tmp[MAX_PATH];
            size_t copiedName = wcstombs(fname_tmp, FindFileData.cFileName, MAX_PATH);
            CV_Assert((copiedName != MAX_PATH) && (copiedName != (size_t)-1));
            fname = fname_tmp;
#else
            fname = FindFileData.cFileName;
#endif
            if (addPath)
                list.push_back(path + "/" + std::string(fname));
            else
                list.push_back(std::string(fname));
        }
    }
#ifdef HAVE_WINRT
    while (FindNextFileW(hFind, &FindFileData));
#else
    while (FindNextFileA(hFind, &FindFileData));
#endif
    FindClose(hFind);
#else
    // Name filter for the dirent-based branch (see the exten description).
    const auto matchesExten = [&exten](const std::string& name) -> bool
    {
        if (exten == "*" || exten == "*.*")
            return true;
        if (!exten.empty() && exten[0] == '*')
        {
            const std::string suffix = exten.substr(1);
            return name.size() >= suffix.size() &&
                   name.compare(name.size() - suffix.size(), suffix.size(), suffix) == 0;
        }
        return name.find(exten) != std::string::npos;
    };

    // Open the directory itself; path + "/" + exten is a search pattern.
    DIR* dp = opendir(path.c_str());
    if (dp == NULL)
    {
        return list;
    }

    struct dirent* dirp;
    while ((dirp = readdir(dp)) != NULL)
    {
        if (dirp->d_type != DT_DIR)
            continue;

        const std::string name(dirp->d_name);
        if (name == "." || name == "..")
            continue;
        if (!matchesExten(name))
            continue;

        list.push_back(addPath ? path + "/" + name : name);
    }
    closedir(dp);
#endif

    return list;
}
/*
 * Lists the files matching exten in path and in each direct sub-folder of
 * path (one level deep).
 *
 * path:    directory to scan
 * exten:   file-name pattern, applied to files only
 * addPath: forwarded to GetListFiles for the returned file entries
 * return:  files of path first, followed by the files of each sub-folder
 *
 * Changes over the original code:
 *  - sub-folders are enumerated with "*" instead of exten; with a file
 *    pattern such as "*.jpg" only folders whose own name matched that
 *    pattern were visited, i.e. normally none;
 *  - sub-folders are always requested with their path, because each one
 *    is handed to GetListFiles as the directory to open; with
 *    addPath == false bare folder names used to be passed on as paths.
 */
std::vector<std::string> Directory::GetListFilesR(const std::string& path, const std::string & exten, bool addPath)
{
    std::vector<std::string> list = Directory::GetListFiles(path, exten, addPath);

    const std::vector<std::string> dirs = Directory::GetListFolders(path, "*", true);
    for (const std::string& dir : dirs)
    {
        const std::vector<std::string> cl = Directory::GetListFiles(dir, exten, addPath);
        list.insert(list.end(), cl.begin(), cl.end());
    }

    return list;
}
在代码注释中已经对代码充分的解释,所以在博客中不多说了