// 数据下载 (training data download): https://download.csdn.net/download/qq_34510308/11049610
// 代码如下 (the code follows):
#include <stdio.h>
#include <time.h>
#include <io.h>

#include <cstdint>
#include <iostream>

#include <opencv2/opencv.hpp>
#include <opencv/cv.h>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>

using namespace std;
using namespace cv;
// Append the paths of all "*.jpg" files found in directory `path` to `files`.
void getAllFiles(string path, vector<string>& files);
// Append every image under D:\data\1 to trainingImages (one flattened row each) with label 1.
void get_1(Mat& trainingImages, vector<int>& trainingLabels);
// Append every image under D:\data\0 to trainingImages (one flattened row each) with label 0.
void get_0(Mat& trainingImages, vector<int>& trainingLabels);
// Disabled loader for a further sample set (declaration only; no definition is visible here).
//void get_g(Mat& trainingImages, vector<int>& trainingLabels);
int main()
{
//获取训练数据
Mat classes;
Mat trainingData;
Mat trainingImages;
vector<int> trainingLabels;
get_1(trainingImages, trainingLabels);
get_0(trainingImages, trainingLabels);
Mat(trainingImages).copyTo(trainingData);
trainingData.convertTo(trainingData, CV_32FC1);
Mat(trainingLabels).copyTo(classes);
//配置SVM训练器参数
cv::Ptr<cv::ml::SVM> svm = cv::ml::SVM::create();
svm->setType(cv::ml::SVM::C_SVC);
svm->setKernel(cv::ml::SVM::LINEAR);
//svm->setDegree(0);
svm->setGamma(0.01);
svm->setC(10.0);
//svm->setCoef0(0);
//svm->setNu(0);
//svm->setP(0);
svm->setTermCriteria(cvTermCriteria(CV_TERMCRIT_ITER, 1000, 0.01));
//训练
svm->train(trainingData, cv::ml::SampleTypes::ROW_SAMPLE, classes);
//保存模型
svm->save("svm.xml");
cout << "训练好了!!!" << endl;
//预测图像
Mat sample = imread("D:/data/1/100.jpg");
sample = sample.reshape(1, 1);
sample.convertTo(sample, CV_32FC1);
Ptr<cv::ml::SVM> svm1 = cv::ml::SVM::StatModel::load<cv::ml::SVM>("svm.xml");
float res = svm1->predict(sample);
cout << res << endl;
getchar();
return 0;
}
// Walk a folder (2019-03-13): collect the paths of all "*.jpg" files directly
// inside directory `path` (non-recursive), using the CRT _findfirst64/_findnext64 API.
//
// path  - directory to scan; each result is built as path + "/" + file name.
// files - output list; matching paths are APPENDED, existing entries are kept.
//
// If the search finds nothing (_findfirst64 returns -1), a message is printed
// to stdout and `files` is left unchanged.
void getAllFiles(string path, vector<string>& files)
{
    // File info filled in by _findfirst64/_findnext64 for each match.
    struct __finddata64_t fileinfo;
    const string pattern = path + "/*.jpg";

    // The search handle is an intptr_t (-1 on failure). A fixed 64-bit integer
    // would not match _findnext64/_findclose on 32-bit builds.
    const intptr_t hFile = _findfirst64(pattern.c_str(), &fileinfo);
    if (hFile == -1)
    {
        cout << "No file is found in " << path << endl;
        return;
    }

    do
    {
        files.push_back(path + "/" + fileinfo.name);
    } while (_findnext64(hFile, &fileinfo) == 0); // 0 = another match found, -1 = no more
    _findclose(hFile);
}
// Load every "*.jpg" image in D:\data\1 as a training sample with label 1.
//
// Each image is read with imread's default (colour) flag and flattened by
// reshape(1, 1) into a single row holding all pixels and channels. The row is
// appended to trainingImages, and label 1 is appended to trainingLabels for
// every row added, so both stay aligned.
// Files that cannot be decoded, or whose flattened length differs from the
// rows already collected, are skipped with a warning. Without this check they
// would throw inside reshape/push_back and abort the whole run.
void get_1(Mat& trainingImages, vector<int>& trainingLabels)
{
    const char* filePath = "D:\\data\\1";
    vector<string> files;
    getAllFiles(filePath, files);
    for (const string& file : files)
    {
        Mat SrcImage = imread(file);
        if (SrcImage.empty())
        {
            cerr << "Skipping unreadable image: " << file << endl;
            continue;
        }
        // One sample per row: 1 channel, 1 row.
        SrcImage = SrcImage.reshape(1, 1);
        if (!trainingImages.empty() && SrcImage.cols != trainingImages.cols)
        {
            cerr << "Skipping image with mismatched size: " << file << endl;
            continue;
        }
        trainingImages.push_back(SrcImage);
        trainingLabels.push_back(1);
    }
}
// Load every "*.jpg" image in D:\data\0 as a training sample with label 0.
//
// Each image is read with imread's default (colour) flag and flattened by
// reshape(1, 1) into a single row holding all pixels and channels. The row is
// appended to trainingImages, and label 0 is appended to trainingLabels for
// every row added, so both stay aligned.
// Files that cannot be decoded, or whose flattened length differs from the
// rows already collected, are skipped with a warning. Without this check they
// would throw inside reshape/push_back and abort the whole run.
void get_0(Mat& trainingImages, vector<int>& trainingLabels)
{
    const char* filePath = "D:\\data\\0";
    vector<string> files;
    getAllFiles(filePath, files);
    for (const string& file : files)
    {
        Mat SrcImage = imread(file);
        if (SrcImage.empty())
        {
            cerr << "Skipping unreadable image: " << file << endl;
            continue;
        }
        // One sample per row: 1 channel, 1 row.
        SrcImage = SrcImage.reshape(1, 1);
        if (!trainingImages.empty() && SrcImage.cols != trainingImages.cols)
        {
            cerr << "Skipping image with mismatched size: " << file << endl;
            continue;
        }
        trainingImages.push_back(SrcImage);
        trainingLabels.push_back(0);
    }
}