opencv svm 多分类问题

svm一般都是两分类的问题,有时候我们需要多分类的时候 以下代码就派上用场了。

话不多说,直接上代码。

// svm_test.cpp : 定义控制台应用程序的入口点。
//

#include "stdafx.h"
#include <iostream>
#include <fstream>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>

using namespace cv;
using namespace std;

#define skyface_API	extern __declspec(dllexport)
skyface_API int sex_detect(vector<float> &feats, const char* modpth);

Mat traindata(string path, int num)
{
	vector<vector<float>> data(num, vector<float>(2048, 0));
	ifstream ifs;
	ifs.open(path);
	for (int i = 0; i < num; i++)
	{
		for (int j = 0; j < 2048; j++)
		{
			ifs >> data[i][j];
		}
	}
	ifs.close();
	Mat class_n_data(data.size(), data.at(0).size(), CV_32FC1);
	for (int i = 0; i < data.size(); i++)
		for (int j = 0; j < data.at(0).size(); j++)
			class_n_data.at<float>(i, j) = data.at(i).at(j);
	return class_n_data;
}

Mat get_traindata3(Mat class1, Mat class2, Mat class3)
{
	Mat traindata(class1.rows + class2.rows + class3.rows , 2048, CV_32FC1);
	Mat tmp = traindata.rowRange(0, class1.rows);
	class1.copyTo(tmp);
	tmp = traindata.rowRange(class1.rows, class1.rows + class2.rows);
	class2.copyTo(tmp);
	tmp = traindata.rowRange(class1.rows + class2.rows, class1.rows + class2.rows + class3.rows);
	class3.copyTo(tmp);
	cout << "获取到训练数据!" << endl;
	return traindata;
}

Mat get_labels3(Mat class1, Mat class2, Mat class3)
{
	Mat labels(class1.rows + class2.rows + class3.rows , 1, CV_32FC1);
	labels.rowRange(0, class1.rows).setTo(1);
	labels.rowRange(class1.rows, class1.rows + class2.rows).setTo(2);
	labels.rowRange(class1.rows + class2.rows, class1.rows + class2.rows + class3.rows).setTo(3);
	return labels;
}


void trainSVM(Mat traindata, Mat labels, string modelpth)
{
	//------------------------ 2. Set up the support vector machines parameters --------------------
	CvSVMParams params;
	params.svm_type = SVM::C_SVC;
	params.C = 0.1;
	params.kernel_type = SVM::LINEAR;
	params.term_crit = TermCriteria(CV_TERMCRIT_ITER, (int)1e7, 1e-6);
	//------------------------ 3. Train the svm ----------------------------------------------------
	cout << "Starting training process" << endl;
	CvSVM svm;
	svm.train(traindata, labels, Mat(), Mat(), params);
	cout << "Finished training process" << endl;
	svm.save("../data/model_AGE.txt");
}

int sex_detect(vector<float> &feats, const char* modpth)
{
	CvSVM SVM;
	SVM.load(modpth);
	int i;
	float* testdata = new float[2048];
	for (int i = 0; i < 2048; i++)
	{
		testdata[i] = feats[i];
	}
	Mat test = Mat(1, 2048, CV_32FC1, testdata);
	float result = SVM.predict(test);
	delete[] testdata;
	return result;
}
int main()
{

  //int labels[3]=[class1,class2,class3];
  
	Mat class1 = traindata("../data/feats_left.txt",40);
	Mat class2 = traindata("../data/feats_right.txt",36);
	Mat class3 = traindata("../data/feats_pos.txt",48);
	
	//Mat traindata = get_traindata(class1, class2);
	//Mat labels = get_labels(class1, class2);
	
	Mat traindata = get_traindata3(class1, class2, class3);
	Mat labels = get_labels3(class1, class2, class3);
	
	trainSVM(traindata, labels, "*");
	CvSVM SVM;
	SVM.load("../data/model_AGE.txt");
	ifstream ifs;
	float testdata[2048];
	ifs.open("../data/feats_test.txt");
		for (int i = 0; i < 2048; i++)
		{
			ifs >> testdata[i];
		}
		Mat test = Mat(1, 2048, CV_32FC1, testdata);
		float result = SVM.predict(test);
		if (result == 1)
			cout << "左偏30度" << endl;
		else if (result == 2)
			cout<< "右偏30度" <<endl;
		else if (result == 3)
		  cout<< "正脸" <<endl;

	ifs.close();
	system("pause");
}


  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值