libsvm与opencv 在c++中使用

配置

libsvm 需自行下载;OpenCV 版本为 4.1.0(对应 opencv_world410.lib),开发环境为 VS2015,Release x64 模式。

头文件

#ifndef __SPARKTRAIN_H__
#define __SPARKTRAIN_H__

#include <vector>
#include <list>
#include <string>
#include <fstream>
#include <io.h>
#include <direct.h>

#include "svm.h"
#include "opencv2\highgui.hpp"
#include "opencv2\opencv.hpp"
#include "opencv2\imgproc\imgproc.hpp"
#include "opencv2\objdetect.hpp"

#define  _CRT_SECURE_NO_WARNINGS

using namespace std;
using namespace cv;

#ifndef _WIN64
#pragma comment(lib, "lib32\\opencv_world410.lib")
#else
#pragma comment(lib, "lib64\\opencv_world410.lib")
#endif

// HOG + libsvm image classifier (RBF C-SVC).
// Training:   init() -> GetFileList() -> readTrainFileList() -> processHogFeature() -> trainLibSVM().
// Prediction: readTestFileList(), load a model into m_svm (svm_load_model), then testLibSVM() per image.
class SparkTrain {

public:
	// Training image paths, one per line of m_trainImageFile.
	vector<string> m_trainImageList;
	// Label of each training image (same index as m_trainImageList).
	vector<int> m_trainLabelList;
	// Test image paths, one per line of m_testImageFile.
	vector<string> m_testImageList;

	// Training list file: "<image path> <label>" per line.
	string m_trainImageFile;
	// Test list file: one image path per line.
	string m_testImageFile;
	// Base directory of the list files and the model file.
	string m_basePath;
	// Model file name, relative to m_basePath.
	string m_SVMModel;

	// HOG features, one CV_32FC1 row per training sample (filled by processHogFeature()).
	Mat  m_dataMat;
	// Training labels as a CV_32FC1 column, row-aligned with m_dataMat.
	Mat  m_labelMat;

	// Predicted label per test image.
	vector<int> m_resultlable;
	// Confidence per test image.
	vector<float> m_prob;
	// libsvm model used by testLibSVM(). Default-initialised to NULL so it is never
	// indeterminate, whichever constructor runs.
	svm_model * m_svm = NULL;

public:
	SparkTrain(void);
	~SparkTrain(void);

	// m_svm is a raw pointer to the model; a copy of SparkTrain would refer to the same
	// svm_model and make its ownership ambiguous, so copying is disabled.
	SparkTrain(const SparkTrain&) = delete;
	SparkTrain& operator=(const SparkTrain&) = delete;

	// Sets the base path, the list-file paths and the model file name.
	bool init(string svm_model);
	// Appends the files under filePath matching `format` to the list file distAll,
	// writing " <lable>" after each path unless lable == 255.
	void GetFileList(string filePath, const char* distAll, string format, int lable);
	// Loads m_trainImageFile into m_trainImageList / m_trainLabelList.
	void readTrainFileList();
	// Loads m_testImageFile into m_testImageList.
	void readTestFileList();
	// Computes HOG features of the training images into m_dataMat / m_labelMat.
	void processHogFeature();
	// Trains the SVM on m_dataMat / m_labelMat and saves it to m_basePath + m_SVMModel.
	void trainLibSVM();

	// Classifies one image with m_svm; prob_estimates receives the per-class probabilities
	// (one slot per class of the model).
	float testLibSVM(Mat src, double prob_estimates[]);
	// CString GetMoudulePath()  // get the directory of the DLL (not implemented)
	// Extracts the class label stored at the end of a training-list line.
	int  getClassFlag(string strPath);

	// Helpers
	// Collects all entries (files and sub-directories) under `path`.
	void GetAllFiles(string path, vector<string>& files);
	// Collects the files under `path` whose names match "*<format>".
	void GetAllFormatFiles(string path, vector<string>& files, string format);

};
#endif  //__SPARKTRAIN_H__

tool.h(注意:测试文件中写的是 #include "tools.h",请让头文件名与之保持一致)

// Appends a printf-style formatted message to a log file named after `logname`.
// NOTE(review): parameters are non-const `char*`, so passing string literals relies on
// MSVC's permissive conversion; consider `const char*` (change declaration and definition together).
void timeuserlog(char* logname, char* fmt, ...);// log printing
// Returns the high-resolution performance-counter time in seconds; subtract the results
// of two calls to measure elapsed time.
double getdoubletime();

源文件

// Must precede every CRT header (including those pulled in by sparktrain.h) to take effect.
#define  _CRT_SECURE_NO_WARNINGS
#include "sparktrain.h"

#include <cstdio>   // fprintf (trainLibSVM)
#include <cstdlib>  // strtol (getClassFlag)

// Constructs an empty trainer: the containers and cv::Mat members are already empty after
// default construction, so only the raw model pointer needs an explicit value.
// (Assigning NULL to a cv::Mat does not reset it; it selects Mat::operator=(const Scalar&),
// which fills the matrix with zeros, so those assignments were dropped.)
SparkTrain::SparkTrain(void)
	: m_svm(NULL) {
}
// Releases the libsvm model held in m_svm, if any.
// svm_free_and_destroy_model frees the model's arrays and the model struct itself, then
// sets m_svm to NULL. Merely nulling the pointer would leak the model.
// cv::Mat members release their buffers in their own destructors.
// NOTE(review): if a caller already ran svm_free_model_content() on m_svm, this relies on
// libsvm having reset the freed arrays to NULL (libsvm 3.x does) — confirm for the bundled version.
SparkTrain::~SparkTrain(void) {
	if (m_svm != NULL)
		svm_free_and_destroy_model(&m_svm);
}
// Sets up the working paths.
// @param svm_model  model file name, stored in m_SVMModel and resolved against m_basePath
//                   by trainLibSVM().
// @return true when a non-empty model name was given. The original fell off the end of
//         this bool function, which is undefined behaviour.
bool SparkTrain::init(string svm_model) {

	// Base directory for the list files and the model.
	// NOTE(review): hard-coded, machine-specific path.
	m_basePath = "F:\\demo\\红外简易demo201981206\\样本\\样本\\";

	m_trainImageFile = m_basePath + "train.txt";
	m_testImageFile = m_basePath + "test.txt";
	m_SVMModel = svm_model;

	return !m_SVMModel.empty();
}

// Appends a list of files to the text file `distAll` (opened in append mode).
// Files are gathered with GetAllFormatFiles(filePath, ..., format). Each one is written on
// its own line:
//   - "<path> <lable>" (a single space separator) when lable != 255;
//   - just "<path>" when lable == 255, meaning "no label" (used for test lists).
// @param filePath  directory to scan
// @param distAll   output list file
// @param format    file-name pattern suffix, e.g. ".jpg"
// @param lable     class label to write, or 255 for none
void SparkTrain::GetFileList(string filePath, const char* distAll, string format, int lable) {

	vector<string> matched;
	GetAllFormatFiles(filePath, matched, format);

	const bool withLabel = (lable != 255);
	ofstream out(distAll, ios::app);
	for (const string& entry : matched)
	{
		if (withLabel)
			out << entry << " " << lable << endl;
		else
			out << entry << endl;
	}
	out.close();
}
//读取文件
void SparkTrain::readTrainFileList() {
	m_trainImageList.clear();
	m_trainLabelList.clear();
	ifstream readData(m_trainImageFile, ios::in);
	string buffer;
	int nClass = 0;
	while (readData)
	{
		if (getline(readData, buffer))
		{

			if (buffer.size() > 0)
			{
				//标签与文件路径之间有一个空格
				nClass = getClassFlag(buffer);
				m_trainLabelList.push_back(nClass);
				string temp(buffer, 0, buffer.size() - 2);
				m_trainImageList.push_back(temp);//图像路径  
			}
		}
	}
	readData.close();


}

void SparkTrain::readTestFileList() {

	ifstream readData(m_testImageFile);  //加载测试图片集合
	string buffer;
	while (readData)
	{
		if (getline(readData, buffer))
		{
			m_testImageList.push_back(buffer);//图像路径       
		}
	}
	readData.close();

}

//提取Hog特征
void SparkTrain::processHogFeature() {

	//样本数目
	int trainSampleNum = m_trainImageList.size();
	//标志位
	m_labelMat=Mat::zeros(trainSampleNum,1,CV_32FC1);
	
	Mat src;
	Mat trainImg=Mat::zeros(40, 40,CV_8UC1);//20 20

	for (int i = 0; i != m_trainImageList.size(); i++)
	{
		src = imread((m_trainImageList[i]), 0);
		if (src.empty())
		{
			continue;
		}
		resize(src, trainImg, Size(40, 40));
		HOGDescriptor hog(Size(40, 40), Size(16, 16), Size(8, 8), Size(8, 8), 9);
		//结果数组
		vector<float> descriptors;
		descriptors.resize(hog.getDescriptorSize());
		//计算特征
		hog.compute(trainImg, descriptors, Size(1, 1), Size(0, 0));
		if (i == 0)
		{
			m_dataMat = Mat::zeros(trainSampleNum,descriptors.size(),CV_32FC1);
		}
		for (vector<float>::size_type j=0; j<descriptors.size();  j++)
		{
			
			//m_dataMat.at<float>(i, j) = descriptors[j];
			float* ptr = m_dataMat.ptr<float>(i);
			ptr[j] = descriptors[j];
		}
		m_labelMat.ptr<float>(i)[0] = m_trainLabelList[i];
	}
}



void SparkTrain::trainLibSVM() {
	
	//设置参数
	svm_parameter param;
	param.svm_type = C_SVC;
	//param.svm_type = EPSILON_SVR;
	param.kernel_type = RBF;
	param.degree = 10.0;
	param.gamma = 0.09;
	param.coef0 = 1.0;
	param.nu = 0.5;
	param.cache_size = 1000;
	param.C = 10.0;
	param.eps = 1e-3;
	param.p = 1.0;
	param.nr_weight = 0;
	param.shrinking = 1;
	param.probability = 1;//后面添加,Release训练时需放开,否则SVM置信度为0


	 //svm_prob读取
	svm_problem svm_prob;
	int sampleNum = m_dataMat.rows;
	int vectorLength = m_dataMat.cols;

	svm_prob.l = sampleNum;
	svm_prob.y = new double[sampleNum];

	for (int i = 0; i < sampleNum; i++)
	{
		float *ptr = m_labelMat.ptr<float>(i);
		svm_prob.y[i] = ptr[0];
	}

	svm_prob.x = new  svm_node *[sampleNum];
	
	
	for (int i = 0; i < sampleNum; i++)
	{
		svm_node * x_space = new svm_node[vectorLength + 1];
		float *ptr = m_dataMat.ptr<float>(i);
		
		for (int j = 0; j < vectorLength; j++)
		{
			//svm_prob.x[i]->index = j;
			//svm_prob.x[i]->value = m_dataMat.at<float>(i, j);;
			x_space[j].index = j;
			x_space[j].value = ptr[j];
		}
		x_space[vectorLength].index = -1;//注意,结束符号,一开始忘记加了
		svm_prob.x[i] = x_space;
		//delete[] x_space;
	}

	svm_model * svm_model = svm_train(&svm_prob, &param);
	
	string path = m_basePath + m_SVMModel;
	svm_save_model(path.c_str(), svm_model);


	for (int i = 0; i < sampleNum; i++)
	{
		delete[] svm_prob.x[i];
	}
	//delete x_space;
	delete svm_prob.y;
	
	svm_free_model_content(svm_model);
}


// Classifies one image with the model in m_svm, using the same preprocessing as
// processHogFeature(): grayscale, resize to 40x40, then HOG (16x16 blocks, 8x8 stride and
// cells, 9 bins).
// @param src            input image: 1-channel, or BGR/BGRA (converted to grayscale first).
//                       Passed by value; cv::Mat copies share the pixel buffer, so this is cheap.
// @param prob_estimates output array with at least svm_get_nr_class(m_svm) slots. It receives
//                       per-class probabilities in svm_get_labels() order. It is left untouched
//                       when the model was trained without probability information.
//                       May be NULL to get only the label.
// @return predicted label, or -1 if `src` is empty or no model is loaded.
//         NOTE(review): -1 may collide with a real class label.
float SparkTrain::testLibSVM(Mat src, double prob_estimates[]) {

	if (src.empty() || m_svm == NULL)
		return -1;

	// Training used grayscale images; convert colour input so the HOG features match.
	Mat gray;
	if (src.channels() == 3)
		cvtColor(src, gray, COLOR_BGR2GRAY);
	else if (src.channels() == 4)
		cvtColor(src, gray, COLOR_BGRA2GRAY);
	else
		gray = src;

	const Size winSize(40, 40);
	Mat tempImage;
	resize(gray, tempImage, winSize);

#ifdef _DEBUG
	// The C-API cvShowImage no longer exists in OpenCV 4; a waitKey() is needed to actually render.
	imshow("testLibSVM", tempImage);
#endif
	HOGDescriptor hog(winSize, Size(16, 16), Size(8, 8), Size(8, 8), 9);

	vector<float> descriptors;
	hog.compute(tempImage, descriptors, Size(1, 1), Size(0, 0));

	// Same 0-based indexing as trainLibSVM(), terminated by index = -1.
	vector<svm_node> inputVector(descriptors.size() + 1);
	for (size_t n = 0; n < descriptors.size(); ++n)
	{
		inputVector[n].index = static_cast<int>(n);
		inputVector[n].value = descriptors[n];
	}
	inputVector[descriptors.size()].index = -1;

	// prob_estimates is already a double*; the original passed &prob_estimates (double**),
	// which does not match svm_predict_probability's signature.
	double resultLabel;
	if (prob_estimates != NULL)
		resultLabel = svm_predict_probability(m_svm, inputVector.data(), prob_estimates);
	else
		resultLabel = svm_predict(m_svm, inputVector.data());

	return static_cast<float>(resultLabel);
}
// Extracts the class label at the end of a training-list line ("<image path> <label>").
// The label is the last whitespace-separated token; trailing whitespace and '\r' are
// ignored. Multi-digit and negative labels are supported (the original read only the
// last character). A string with no separator is parsed as a whole.
// @return the label, or -1 if the string is empty or its last token is not an integer.
//         NOTE(review): -1 is ambiguous if -1 is used as a real class label.
int  SparkTrain::getClassFlag(string strPath) {

	const size_t last = strPath.find_last_not_of(" \t\r\n");
	if (last == string::npos)
		return -1;  // empty / whitespace-only (the original indexed strPath[-1] here)

	const size_t sep = strPath.find_last_of(" \t", last);
	const size_t begin = (sep == string::npos) ? 0 : sep + 1;
	const string token = strPath.substr(begin, last + 1 - begin);

	char* end = NULL;
	const long value = strtol(token.c_str(), &end, 10);
	if (end == token.c_str() || *end != '\0')
		return -1;
	return static_cast<int>(value);
}

// Recursively appends every entry under `path` to `files`.
// Each sub-directory's own path is added before its contents; "." and ".." are skipped.
// Windows-only: uses the CRT _findfirst/_findnext API from <io.h>.
void SparkTrain::GetAllFiles(string path, vector<string>& files) {
	struct _finddata_t entry;
	const string pattern = path + "\\*";
	const long long handle = _findfirst(pattern.c_str(), &entry);
	if (handle == -1)
		return;

	do
	{
		const string fullPath = path + "\\" + entry.name;
		const bool isDir = (entry.attrib & _A_SUBDIR) != 0;
		if (!isDir)
		{
			files.push_back(fullPath);
		}
		else if (strcmp(entry.name, ".") != 0 && strcmp(entry.name, "..") != 0)
		{
			files.push_back(fullPath);
			GetAllFiles(fullPath, files);
		}
	} while (_findnext(handle, &entry) == 0);

	_findclose(handle);
}

// Appends to `files` the full path of every file under `path` whose name matches the
// wildcard "*<format>" (e.g. ".jpg" -> "*.jpg"). Matching is done by _findfirst, so it
// follows the Windows CRT rules. The function then recurses into every sub-directory.
// Windows-only: uses _findfirst/_findnext from <io.h>.
void SparkTrain::GetAllFormatFiles(string path, vector<string>& files, string format) {

	struct _finddata_t fileinfo;

	// Pass 1: matching files directly inside `path`.
	const string filePattern = path + "\\*" + format;
	long long hFile = _findfirst(filePattern.c_str(), &fileinfo);
	if (hFile != -1)
	{
		do
		{
			if ((fileinfo.attrib & _A_SUBDIR) == 0)
				files.push_back(path + "\\" + fileinfo.name);
		} while (_findnext(hFile, &fileinfo) == 0);
		_findclose(hFile);
	}

	// Pass 2: recurse into sub-directories. They must be enumerated with a plain "*",
	// because directory names normally do not match "*<format>". That is why the
	// single-pass version never actually descended into sub-folders.
	const string dirPattern = path + "\\*";
	hFile = _findfirst(dirPattern.c_str(), &fileinfo);
	if (hFile != -1)
	{
		do
		{
			if ((fileinfo.attrib & _A_SUBDIR) != 0
				&& strcmp(fileinfo.name, ".") != 0 && strcmp(fileinfo.name, "..") != 0)
			{
				GetAllFormatFiles(path + "\\" + fileinfo.name, files, format);
			}
		} while (_findnext(hFile, &fileinfo) == 0);
		_findclose(hFile);
	}
}

tool.cpp

#include <string>
#include <iostream>
#include <windows.h>

// Appends one printf-style formatted line to a log file.
// The message (fmt plus the variadic arguments) is formatted into a 1024-byte buffer and
// truncated if longer.
// On Windows, the line is prefixed with a local timestamp "[YYYY-MM-DD hh:mm:ss mmm] " and
// written to d:\Log\<logname>_time_<Y>_<M>_<D>.txt.
// Elsewhere, it is written without a timestamp to /Log/<logname>_time.txt.
// The Log directory is not created: if fopen fails, the message is dropped.
// A file name too long for the path buffer also drops the message, instead of overflowing
// (sprintf) or aborting the process (sprintf_s).
void timeuserlog(char* logname, char* fmt, ...)// log printing
{
	char info[1024];
	va_list args;
	va_start(args, fmt);
	// Bounded: vsprintf could overrun `info` for long messages.
	vsnprintf(info, sizeof(info), fmt, args);
	va_end(args);
	info[sizeof(info) - 1] = '\0';

#ifdef _WIN32
	char szTime[100];
	SYSTEMTIME now_time;
	GetLocalTime(&now_time);

	sprintf_s(szTime, "[%4.4d-%2.2d-%2.2d %2.2d:%2.2d:%2.2d %3.3d] ",
		now_time.wYear, now_time.wMonth, now_time.wDay,
		now_time.wHour, now_time.wMinute, now_time.wSecond, now_time.wMilliseconds);

	char filename[260];
	const int nameLen = snprintf(filename, sizeof(filename), "d:\\Log\\%s_time_%d_%d_%d.txt", logname, now_time.wYear, now_time.wMonth, now_time.wDay);
	if (nameLen < 0 || nameLen >= static_cast<int>(sizeof(filename)))
		return;
	FILE * fp = fopen(filename, "a");
	if (fp)
	{
		fwrite(szTime, 1, strlen(szTime), fp);
		fwrite(info, 1, strlen(info), fp);
		fwrite("\n", 1, 1, fp);
		fclose(fp);
	}
#else
	char filename[260];
	const int nameLen = snprintf(filename, sizeof(filename), "/Log/%s_time.txt", logname);
	if (nameLen < 0 || nameLen >= static_cast<int>(sizeof(filename)))
		return;
	FILE * fp = fopen(filename, "a");
	if (fp)
	{
		fwrite(info, 1, strlen(info), fp);
		fwrite("\n", 1, 1, fp);
		fclose(fp);
	}
#endif
}

//获取时间
double getdoubletime()
{
	LARGE_INTEGER t, f;
	QueryPerformanceCounter(&t);
	QueryPerformanceFrequency(&f);
	return t.QuadPart*1.0 / f.QuadPart;
}

测试文件

#include "sparktrain.h"
#include "tools.h"
#include <string>
#define  _CRT_SECURE_NO_WARNINGS


using namespace std;



int main(int argc, char** argv)
{

	SparkTrain a;
	//初始化
	a.init("spark.model");
	//将正样本写入txt,位置在当前目录下,标志位为2,标志位不能为255
	//a.GetFileList("F:\\demo\\红外简易demo201981206\\样本\\样本\\1", "F:\\demo\\红外简易demo201981206\\样本\\样本\\train.txt", ".jpg", 1);
	将负样本追加写入txt,位置在当前目录下,标志位为1,标志位不能为255
	//a.GetFileList("F:\\demo\\红外简易demo201981206\\样本\\样本\\0", "F:\\demo\\红外简易demo201981206\\样本\\样本\\train.txt", ".jpg", 0);
	读取文件
	//a.readTrainFileList();
	提取HOG特征
	//a.processHogFeature();
	训练分类器
	//a.trainLibSVM();

	将测试正样本写入txt,255代表标志位为空
	//a.GetFileList("F:\\demo\\红外简易demo201981206\\样本\\样本\\0", a.m_testImageFile.c_str(), ".jpg", 255);
	//读取测试图片
	a.readTestFileList();
	
	string path = a.m_basePath + a.m_SVMModel;
	a.m_svm = svm_load_model(path.c_str());
	
	for (int i = 0; i != a.m_testImageList.size(); i++)
	{
		Mat src;
		src = imread((a.m_testImageList[i]).c_str(), 0);
		if (src.empty())
		{
			continue;
		}

		double  starttime = getdoubletime();
		char info[1024];

		double prob[2] = {0};
		float temp = a.testLibSVM(src,prob);
		double  endtime = getdoubletime();
		sprintf(info, "帧序号为%d  耗时%f\n", i, (endtime - starttime) * 1000);
		//timeuserlog("testsvm", info);
		cout << info<< endl;
		//保存置信度
		a.m_prob.push_back(prob);
		//保存分类结果标签
		a.m_resultlable.push_back(temp);
		
	}

	
	svm_free_model_content(a.m_svm);
	
	system("pause");
	return 0;

}
将 Libs 解压到 sln 同级目录。项目属性 → C/C++ → 附加包含目录:填写 Libs/x86/opencv_v3.4.0/include 路径;链接器 → 所有选项 → 附加库目录:填写 Libs/x86/opencv_v3.4.0/lib;附加依赖项: opencv_aruco340.lib;opencv_aruco340d.lib;opencv_bgsegm340.lib;opencv_bgsegm340d.lib;opencv_bioinspired340.lib;opencv_bioinspired340d.lib;opencv_calib3d340.lib;opencv_calib3d340d.lib;opencv_ccalib340.lib;opencv_ccalib340d.lib;opencv_core340.lib;opencv_core340d.lib;opencv_datasets340.lib;opencv_datasets340d.lib;opencv_dnn340.lib;opencv_dnn340d.lib;opencv_dpm340.lib;opencv_dpm340d.lib;opencv_face340.lib;opencv_face340d.lib;opencv_features2d340.lib;opencv_features2d340d.lib;opencv_flann340.lib;opencv_flann340d.lib;opencv_fuzzy340.lib;opencv_fuzzy340d.lib;opencv_highgui340.lib;opencv_highgui340d.lib;opencv_imgcodecs340.lib;opencv_imgcodecs340d.lib;opencv_imgproc340.lib;opencv_imgproc340d.lib;opencv_img_hash340.lib;opencv_img_hash340d.lib;opencv_line_descriptor340.lib;opencv_line_descriptor340d.lib;opencv_ml340.lib;opencv_ml340d.lib;opencv_objdetect340.lib;opencv_objdetect340d.lib;opencv_optflow340.lib;opencv_optflow340d.lib;opencv_phase_unwrapping340.lib;opencv_phase_unwrapping340d.lib;opencv_photo340.lib;opencv_photo340d.lib;opencv_plot340.lib;opencv_plot340d.lib;opencv_reg340.lib;opencv_reg340d.lib;opencv_rgbd340.lib;opencv_rgbd340d.lib;opencv_saliency340.lib;opencv_saliency340d.lib;opencv_shape340.lib;opencv_shape340d.lib;opencv_stereo340.lib;opencv_stereo340d.lib;opencv_stitching340.lib;opencv_stitching340d.lib;opencv_structured_light340.lib;opencv_structured_light340d.lib;opencv_superres340.lib;opencv_superres340d.lib;opencv_surface_matching340.lib;opencv_surface_matching340d.lib;opencv_text340.lib;opencv_text340d.lib;opencv_tracking340.lib;opencv_tracking340d.lib;opencv_video340.lib;opencv_video340d.lib;opencv_videoio340.lib;opencv_videoio340d.lib;opencv_videostab340.lib;opencv_videostab340d.lib;opencv_xfeatures2d340.lib;opencv_xfeatures2d340d.lib;opencv_ximgproc340.lib;opencv_ximgproc340d.lib;opencv_xobjdetect340.lib;opencv_xobjdetect340d.lib;opencv_xphoto340.lib;opencv_xphoto340d.lib;
注意:文件名以 d 结尾的是 Debug 版库。Debug 配置只应链接 *d.lib,Release 配置只应链接不带 d 的库,二者不要混在同一个配置里。另外,这里的路径和库针对 OpenCV 3.4.0 / x86,与正文使用的 OpenCV 4.1.0(opencv_world410.lib,x64)不一致,请按实际使用的版本和平台调整。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值