配置
libsvm 需自行下载;OpenCV 为 4.1.0 版本(对应下文链接的 opencv_world410.lib);使用 VS2015,以 Release x64 模式编译。
头文件
#ifndef __SPARKTRAIN_H__
#define __SPARKTRAIN_H__
#include <io.h>
#include <direct.h>
#include <fstream>
#include <iostream>
#include <list>
#include <string>
#include <vector>
#include "svm.h"
#include "opencv2\highgui.hpp"
#include "opencv2\opencv.hpp"
#include "opencv2\imgproc\imgproc.hpp"
#include "opencv2\objdetect.hpp"
#define _CRT_SECURE_NO_WARNINGS
using namespace std;
using namespace cv;
#ifndef _WIN64
#pragma comment(lib, "lib32\\opencv_world410.lib")
#else
#pragma comment(lib, "lib64\\opencv_world410.lib")
#endif
// Trains a libsvm classifier on HOG features extracted from sample images and
// classifies single images with a loaded model.
class SparkTrain {
public:
// Paths of the training images (filled by readTrainFileList)
vector<string> m_trainImageList;
// Class label of each training image; filled together with m_trainImageList
vector<int> m_trainLabelList;
// Paths of the test images (filled by readTestFileList)
vector<string> m_testImageList;
// List files for the training and test samples; init() sets them to
// train.txt / test.txt under m_basePath
string m_trainImageFile;
string m_testImageFile;
// Base directory of the samples and the model file
string m_basePath;
// File name of the SVM model (relative to m_basePath)
string m_SVMModel;
// Feature matrix (one row per sample) and label column produced by
// processHogFeature() and consumed by trainLibSVM()
Mat m_dataMat;
Mat m_labelMat;
//svm_model * m_svm;
// Predicted class labels
vector<int> m_resultlable;
// Confidence values
vector<float>m_prob;
// libsvm model used by testLibSVM(); it is assigned by the caller
// (e.g. from svm_load_model)
svm_model * m_svm;
public:
SparkTrain(void);
~SparkTrain(void);
// Initialization: sets the base path, the list-file paths and the model name
bool init(string svm_model);
// Appends the files found under filePath with the given format to the list
// file distAll, optionally followed by a label
void GetFileList(string filePath, const char* distAll, string format, int lable);
// Reads the training list file into m_trainImageList / m_trainLabelList
void readTrainFileList();
// Reads the test list file into m_testImageList
void readTestFileList();
// Extracts the HOG features of all training images
void processHogFeature();
// Trains the classifier and saves the model
void trainLibSVM();
// Classifies one image; prob_estimates receives the probability estimates
float testLibSVM(Mat src, double prob_estimates[]);
//CString GetMoudulePath()// get the directory that contains the dll
// Parses the class label from one line of the training list file
int getClassFlag(string strPath);
// Helper functions
// Collects all file names under path
void GetAllFiles(string path, vector<string>& files);
// Collects the file names with the given format under path
void GetAllFormatFiles(string path, vector<string>& files, string format);
};
#endif //__SPARKTRAIN_H__
tool.h(注意:下文测试文件中写的是 #include "tools.h",实际文件名需与之保持一致)
// Logging: formats a printf-style message and appends it to a log file whose
// name is derived from logname.
void timeuserlog(char* logname, char* fmt, ...);//log output
// Returns the current high-resolution counter value converted to seconds.
double getdoubletime();
源文件
#include "sparktrain.h"
#define _CRT_SECURE_NO_WARNINGS
// Creates a trainer with no data and no model loaded.
// The cv::Mat members are default-constructed as empty matrices, so they need
// no explicit initialisation; assigning NULL to a cv::Mat is not a valid way
// to "empty" it. Only the raw model pointer has to be reset.
SparkTrain::SparkTrain(void)
    : m_svm(nullptr) {
}
// Releases the feature/label matrices and the libsvm model.
// The model struct is freed here because nothing else frees it: resetting the
// pointer alone leaks it. svm_free_and_destroy_model() tolerates a model whose
// content was already released with svm_free_model_content(), because that
// call nulls the inner pointers.
// NOTE(review): a caller that destroys the model itself must reset m_svm to
// NULL afterwards, and SparkTrain objects must not be copied while m_svm is
// set - confirm against all callers.
SparkTrain::~SparkTrain(void) {
    m_dataMat.release();
    m_labelMat.release();
    if (m_svm != NULL) {
        svm_free_and_destroy_model(&m_svm);  // also sets m_svm to NULL
    }
}
// Sets the sample directory, the train/test list-file paths and the model name.
// @param svm_model  file name of the SVM model, relative to m_basePath
// @return true when a non-empty model name was given, false otherwise.
//         (The original fell off the end of a bool function without a return.)
bool SparkTrain::init(string svm_model) {
    // Sample directory; the list files live directly inside it.
    m_basePath = "F:\\demo\\红外简易demo201981206\\样本\\样本\\";
    m_trainImageFile = m_basePath + "train.txt";
    m_testImageFile = m_basePath + "test.txt";
    m_SVMModel = svm_model;
    return !m_SVMModel.empty();
}
// Appends the paths of all files with the given format found under filePath
// to the list file distAll (the file is opened in append mode).
// @param filePath  directory to scan
// @param distAll   path of the list file to append to
// @param format    file format / extension to collect, e.g. ".jpg"
// @param lable     label written after each path, separated by ONE space;
//                  the value 255 means "no label" and only the path is written
// Does nothing when distAll is null or the list file cannot be opened.
void SparkTrain::GetFileList(string filePath, const char* distAll, string format, int lable) {
    if (distAll == NULL) {
        return;
    }
    vector<string> files;
    GetAllFormatFiles(filePath, files, format);

    ofstream ofn(distAll, ios::app);
    if (!ofn.is_open()) {
        return;
    }
    const bool writeLabel = (lable != 255);
    for (vector<string>::size_type i = 0; i < files.size(); ++i) {
        ofn << files[i];
        if (writeLabel) {
            ofn << " " << lable;
        }
        ofn << '\n';  // the stream is flushed once, when it is closed
    }
    ofn.close();
}
//读取文件
void SparkTrain::readTrainFileList() {
m_trainImageList.clear();
m_trainLabelList.clear();
ifstream readData(m_trainImageFile, ios::in);
string buffer;
int nClass = 0;
while (readData)
{
if (getline(readData, buffer))
{
if (buffer.size() > 0)
{
//标签与文件路径之间有一个空格
nClass = getClassFlag(buffer);
m_trainLabelList.push_back(nClass);
string temp(buffer, 0, buffer.size() - 2);
m_trainImageList.push_back(temp);//图像路径
}
}
}
readData.close();
}
void SparkTrain::readTestFileList() {
ifstream readData(m_testImageFile); //加载测试图片集合
string buffer;
while (readData)
{
if (getline(readData, buffer))
{
m_testImageList.push_back(buffer);//图像路径
}
}
readData.close();
}
//提取Hog特征
void SparkTrain::processHogFeature() {
//样本数目
int trainSampleNum = m_trainImageList.size();
//标志位
m_labelMat=Mat::zeros(trainSampleNum,1,CV_32FC1);
Mat src;
Mat trainImg=Mat::zeros(40, 40,CV_8UC1);//20 20
for (int i = 0; i != m_trainImageList.size(); i++)
{
src = imread((m_trainImageList[i]), 0);
if (src.empty())
{
continue;
}
resize(src, trainImg, Size(40, 40));
HOGDescriptor hog(Size(40, 40), Size(16, 16), Size(8, 8), Size(8, 8), 9);
//结果数组
vector<float> descriptors;
descriptors.resize(hog.getDescriptorSize());
//计算特征
hog.compute(trainImg, descriptors, Size(1, 1), Size(0, 0));
if (i == 0)
{
m_dataMat = Mat::zeros(trainSampleNum,descriptors.size(),CV_32FC1);
}
for (vector<float>::size_type j=0; j<descriptors.size(); j++)
{
//m_dataMat.at<float>(i, j) = descriptors[j];
float* ptr = m_dataMat.ptr<float>(i);
ptr[j] = descriptors[j];
}
m_labelMat.ptr<float>(i)[0] = m_trainLabelList[i];
}
}
void SparkTrain::trainLibSVM() {
//设置参数
svm_parameter param;
param.svm_type = C_SVC;
//param.svm_type = EPSILON_SVR;
param.kernel_type = RBF;
param.degree = 10.0;
param.gamma = 0.09;
param.coef0 = 1.0;
param.nu = 0.5;
param.cache_size = 1000;
param.C = 10.0;
param.eps = 1e-3;
param.p = 1.0;
param.nr_weight = 0;
param.shrinking = 1;
param.probability = 1;//后面添加,Release训练时需放开,否则SVM置信度为0
//svm_prob读取
svm_problem svm_prob;
int sampleNum = m_dataMat.rows;
int vectorLength = m_dataMat.cols;
svm_prob.l = sampleNum;
svm_prob.y = new double[sampleNum];
for (int i = 0; i < sampleNum; i++)
{
float *ptr = m_labelMat.ptr<float>(i);
svm_prob.y[i] = ptr[0];
}
svm_prob.x = new svm_node *[sampleNum];
for (int i = 0; i < sampleNum; i++)
{
svm_node * x_space = new svm_node[vectorLength + 1];
float *ptr = m_dataMat.ptr<float>(i);
for (int j = 0; j < vectorLength; j++)
{
//svm_prob.x[i]->index = j;
//svm_prob.x[i]->value = m_dataMat.at<float>(i, j);;
x_space[j].index = j;
x_space[j].value = ptr[j];
}
x_space[vectorLength].index = -1;//注意,结束符号,一开始忘记加了
svm_prob.x[i] = x_space;
//delete[] x_space;
}
svm_model * svm_model = svm_train(&svm_prob, ¶m);
string path = m_basePath + m_SVMModel;
svm_save_model(path.c_str(), svm_model);
for (int i = 0; i < sampleNum; i++)
{
delete[] svm_prob.x[i];
}
//delete x_space;
delete svm_prob.y;
svm_free_model_content(svm_model);
}
// Classifies one image with the model in m_svm.
// @param src             image to classify; it is resized to 40x40 before the
//                        HOG feature is computed
// @param prob_estimates  array that receives the probability estimate of each
//                        class, so it needs one element per class of the
//                        model. May be NULL, in which case only the label is
//                        predicted.
// @return the predicted class label, or -1 when src is empty or no model is
//         loaded.
// NOTE(review): the original author reported errors when src was passed by
// pointer or reference and kept it by value - cause unknown.
float SparkTrain::testLibSVM(Mat src, double prob_estimates[]) {
    if (src.empty() || m_svm == NULL) {
        return -1;
    }

    Mat tempImage;
    resize(src, tempImage, Size(40, 40));
#ifdef _DEBUG
    // cvShowImage belongs to the legacy C API and cannot take a cv::Mat.
    imshow("testLibSVM", tempImage);
#endif

    // Must be the same descriptor layout that was used for training.
    HOGDescriptor hog(Size(40, 40), Size(16, 16), Size(8, 8), Size(8, 8), 9);
    vector<float> descriptors;
    hog.compute(tempImage, descriptors, Size(1, 1), Size(0, 0));

    // libsvm input vector: 0-based indices, terminated by index -1.
    vector<svm_node> inputVector(descriptors.size() + 1);
    for (vector<float>::size_type k = 0; k < descriptors.size(); ++k) {
        inputVector[k].index = static_cast<int>(k);
        inputVector[k].value = descriptors[k];
    }
    inputVector.back().index = -1;
    inputVector.back().value = 0.0;

    // prob_estimates already is a double*; taking its address again (as the
    // original did) hands libsvm a double** and does not compile.
    const double label = (prob_estimates != NULL)
        ? svm_predict_probability(m_svm, &inputVector[0], prob_estimates)
        : svm_predict(m_svm, &inputVector[0]);
    return static_cast<float>(label);
}
// Parses the class label from one line of the training list file.
// The label is the run of decimal digits at the end of strPath; trailing
// blanks, tabs and line-end characters are ignored. A single trailing digit
// therefore yields the same value as before, and labels with several digits
// are read completely.
// @param strPath  line of the form "<image path> <label>"
// @return the label (non-negative), or -1 when strPath is empty, does not end
//         in a digit, or the number has more than 9 digits.
int SparkTrain::getClassFlag(string strPath) {
    const string::size_type last = strPath.find_last_not_of(" \t\r\n");
    if (last == string::npos) {
        return -1;  // empty or whitespace-only input
    }

    // Walk backwards over the trailing digits; 'first' ends on the first digit.
    string::size_type first = last + 1;
    while (first > 0 && strPath[first - 1] >= '0' && strPath[first - 1] <= '9') {
        --first;
    }
    const string::size_type digitCount = last + 1 - first;
    if (digitCount == 0 || digitCount > 9) {
        return -1;  // no digits, or too many to fit into an int safely
    }

    int label = 0;
    for (string::size_type k = first; k <= last; ++k) {
        label = label * 10 + (strPath[k] - '0');
    }
    return label;
}
// Collects the full path of every entry below 'path' into 'files'.
// Sub-directories are listed themselves and then descended into; the "." and
// ".." entries are ignored. Nothing is added when 'path' cannot be enumerated.
void SparkTrain::GetAllFiles(string path, vector<string>& files) {
    struct _finddata_t entry;
    const string pattern = path + "\\*";
    long long handle = _findfirst(pattern.c_str(), &entry);
    if (handle == -1) {
        return;
    }
    do {
        const bool isDirectory = (entry.attrib & _A_SUBDIR) != 0;
        if (isDirectory) {
            const bool isDotEntry =
                strcmp(entry.name, ".") == 0 || strcmp(entry.name, "..") == 0;
            if (!isDotEntry) {
                const string subDir = path + "\\" + entry.name;
                files.push_back(subDir);
                GetAllFiles(subDir, files);
            }
        } else {
            files.push_back(path + "\\" + entry.name);
        }
    } while (_findnext(handle, &entry) == 0);
    _findclose(handle);
}
// Collects the full path of every file with the given format below 'path',
// including files in sub-directories.
// @param path    directory to scan
// @param files   receives the matching paths (appended)
// @param format  file format / extension, e.g. ".jpg"
// The original searched with the single pattern "*<format>", which never
// matches ordinary sub-directories, so its recursive branch could not run.
// Files and sub-directories are therefore enumerated in two separate passes.
void SparkTrain::GetAllFormatFiles(string path, vector<string>& files, string format) {
    struct _finddata_t fileinfo;

    // Pass 1: files in this directory that match the format.
    const string filePattern = path + "\\*" + format;
    intptr_t hFile = _findfirst(filePattern.c_str(), &fileinfo);
    if (hFile != -1) {
        do {
            if (!(fileinfo.attrib & _A_SUBDIR)) {
                files.push_back(path + "\\" + fileinfo.name);
            }
        } while (_findnext(hFile, &fileinfo) == 0);
        _findclose(hFile);
    }

    // Pass 2: descend into every sub-directory, whatever its name.
    const string dirPattern = path + "\\*";
    hFile = _findfirst(dirPattern.c_str(), &fileinfo);
    if (hFile != -1) {
        do {
            if ((fileinfo.attrib & _A_SUBDIR)
                && strcmp(fileinfo.name, ".") != 0
                && strcmp(fileinfo.name, "..") != 0) {
                GetAllFormatFiles(path + "\\" + fileinfo.name, files, format);
            }
        } while (_findnext(hFile, &fileinfo) == 0);
        _findclose(hFile);
    }
}
tool.cpp
#include <string>
#include <iostream>
#include <windows.h>
// Logging: formats a printf-style message and appends it as one line to a log
// file.
// @param logname  prefix of the log file name
// @param fmt      printf-style format string, followed by its arguments
// On Windows the file is d:\Log\<logname>_time_<year>_<month>_<day>.txt and
// each line starts with a time stamp; elsewhere it is /Log/<logname>_time.txt.
// Messages longer than the internal buffer are truncated. Logging is
// best-effort: nothing is written when an argument is null, the file name does
// not fit, or the file cannot be opened.
void timeuserlog(char* logname, char* fmt, ...)
{
    if (logname == NULL || fmt == NULL)
    {
        return;
    }
    char info[1024];
    info[0] = '\0';
    va_list args;
    va_start(args, fmt);
    // Bounded formatting: the unbounded vsprintf could overrun 'info'.
    vsnprintf(info, sizeof(info), fmt, args);
    va_end(args);
    info[sizeof(info) - 1] = '\0';
#ifdef _WIN32
    char szTime[100];
    SYSTEMTIME now_time;
    GetLocalTime(&now_time);
    snprintf(szTime, sizeof(szTime), "[%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d %3.3d] ",
        now_time.wYear, now_time.wMonth, now_time.wDay,
        now_time.wHour, now_time.wMinute, now_time.wSecond, now_time.wMilliseconds);
    char filename[260];
    const int written = snprintf(filename, sizeof(filename), "d:\\Log\\%s_time_%d_%d_%d.txt", logname, now_time.wYear, now_time.wMonth, now_time.wDay);
    if (written < 0 || written >= static_cast<int>(sizeof(filename)))
    {
        return;  // logname too long for a usable path
    }
    FILE * fp = fopen(filename, "a");
    if (fp)
    {
        fwrite(szTime, 1, strlen(szTime), fp);
        fwrite(info, 1, strlen(info), fp);
        fwrite("\n", 1, 1, fp);
        fclose(fp);
    }
#else
    char filename[260];
    const int written = snprintf(filename, sizeof(filename), "/Log/%s_time.txt", logname);
    if (written < 0 || written >= static_cast<int>(sizeof(filename)))
    {
        return;  // logname too long for a usable path
    }
    FILE * fp = fopen(filename, "a");
    if (fp)
    {
        fwrite(info, 1, strlen(info), fp);
        fwrite("\n", 1, 1, fp);
        fclose(fp);
    }
#endif
}
//获取时间
double getdoubletime()
{
LARGE_INTEGER t, f;
QueryPerformanceCounter(&t);
QueryPerformanceFrequency(&f);
return t.QuadPart*1.0 / f.QuadPart;
}
测试文件
#include "sparktrain.h"
#include "tools.h"
#include <string>
#define _CRT_SECURE_NO_WARNINGS
using namespace std;
int main(int argc, char** argv)
{
SparkTrain a;
//初始化
a.init("spark.model");
//将正样本写入txt,位置在当前目录下,标志位为2,标志位不能为255
//a.GetFileList("F:\\demo\\红外简易demo201981206\\样本\\样本\\1", "F:\\demo\\红外简易demo201981206\\样本\\样本\\train.txt", ".jpg", 1);
将负样本追加写入txt,位置在当前目录下,标志位为1,标志位不能为255
//a.GetFileList("F:\\demo\\红外简易demo201981206\\样本\\样本\\0", "F:\\demo\\红外简易demo201981206\\样本\\样本\\train.txt", ".jpg", 0);
读取文件
//a.readTrainFileList();
提取HOG特征
//a.processHogFeature();
训练分类器
//a.trainLibSVM();
将测试正样本写入txt,255代表标志位为空
//a.GetFileList("F:\\demo\\红外简易demo201981206\\样本\\样本\\0", a.m_testImageFile.c_str(), ".jpg", 255);
//读取测试图片
a.readTestFileList();
string path = a.m_basePath + a.m_SVMModel;
a.m_svm = svm_load_model(path.c_str());
for (int i = 0; i != a.m_testImageList.size(); i++)
{
Mat src;
src = imread((a.m_testImageList[i]).c_str(), 0);
if (src.empty())
{
continue;
}
double starttime = getdoubletime();
char info[1024];
double prob[2] = {0};
float temp = a.testLibSVM(src,prob);
double endtime = getdoubletime();
sprintf(info, "帧序号为%d 耗时%f\n", i, (endtime - starttime) * 1000);
//timeuserlog("testsvm", info);
cout << info<< endl;
//保存置信度
a.m_prob.push_back(prob);
//保存分类结果标签
a.m_resultlable.push_back(temp);
}
svm_free_model_content(a.m_svm);
system("pause");
return 0;
}