SVM 通常用于二分类问题;当我们需要多分类时,下面的代码就能派上用场(OpenCV 的 C_SVC 类型本身即支持 n 类分类,n ≥ 2)。
话不多说,直接上代码。
// svm_test.cpp : 定义控制台应用程序的入口点。
//
#include "stdafx.h"
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>
#include <opencv2/core/core.hpp>
#include <opencv2/highgui/highgui.hpp>
#include <opencv2/ml/ml.hpp>
using namespace cv;
using namespace std;
#define skyface_API extern __declspec(dllexport)
skyface_API int sex_detect(vector<float> &feats, const char* modpth);
/// Reads `num` samples of `dims` whitespace-separated floats each from a text file.
///
/// Values are consumed in row-major order: the first `dims` values form row 0,
/// the next `dims` values form row 1, and so on.
///
/// @param path  text file holding the feature values
/// @param num   number of samples (rows) to read; must be >= 0
/// @param dims  values per sample (columns); defaults to 2048, the feature
///              length used throughout this file
/// @return      a num x dims matrix of type CV_32FC1
/// @throws std::runtime_error if the file cannot be opened or contains fewer
///         than num * dims parseable values
Mat traindata(string path, int num, int dims = 2048)
{
    CV_Assert(num >= 0 && dims > 0);

    ifstream ifs(path.c_str());
    if (!ifs.is_open())
        throw runtime_error("cannot open feature file: " + path);

    // Parse straight into the matrix rather than staging a vector<vector<float>>
    // copy; this also makes num == 0 well-defined (an empty 0 x dims matrix).
    Mat samples(num, dims, CV_32FC1);
    for (int i = 0; i < num; ++i)
    {
        float* row = samples.ptr<float>(i);
        for (int j = 0; j < dims; ++j)
        {
            // A failed extraction means the file is short or malformed; training
            // on silently zero-filled rows would corrupt the model.
            if (!(ifs >> row[j]))
                throw runtime_error("not enough feature values in file: " + path);
        }
    }
    return samples;
}
/// Stacks three per-class sample matrices vertically into one training matrix.
///
/// Rows of class1 come first, then class2, then class3 — the same order in
/// which get_labels3 assigns labels 1, 2 and 3.
///
/// @param class1, class2, class3  CV_32FC1 matrices with equal column counts
/// @return (class1.rows + class2.rows + class3.rows) x class1.cols CV_32FC1 matrix
Mat get_traindata3(Mat class1, Mat class2, Mat class3)
{
    const Mat classes[] = { class1, class2, class3 };
    const int dims = class1.cols;

    int total_rows = 0;
    for (int k = 0; k < 3; ++k)
    {
        // copyTo() reallocates its destination when size/type differ; the
        // rowRange view would then be detached and the stacked matrix would keep
        // uninitialized rows. Reject mismatched inputs up front instead.
        CV_Assert(classes[k].cols == dims && classes[k].type() == CV_32FC1);
        total_rows += classes[k].rows;
    }

    Mat traindata(total_rows, dims, CV_32FC1);
    int offset = 0;
    for (int k = 0; k < 3; ++k)
    {
        Mat dst = traindata.rowRange(offset, offset + classes[k].rows);
        classes[k].copyTo(dst);
        offset += classes[k].rows;
    }
    cout << "获取到训练数据!" << endl;
    return traindata;
}
/// Builds the label column that matches get_traindata3's row order:
/// rows from class1 get label 1, class2 label 2, class3 label 3.
///
/// @return (class1.rows + class2.rows + class3.rows) x 1 matrix of type CV_32FC1
Mat get_labels3(Mat class1, Mat class2, Mat class3)
{
    const int counts[] = { class1.rows, class2.rows, class3.rows };
    Mat labels(counts[0] + counts[1] + counts[2], 1, CV_32FC1);

    int first = 0;
    for (int cls = 1; cls <= 3; ++cls)
    {
        const int last = first + counts[cls - 1];
        labels.rowRange(first, last).setTo(cls);
        first = last;
    }
    return labels;
}
/// Trains a linear C-SVC classifier and saves the model to disk.
///
/// C_SVC handles n-class problems (n >= 2) directly, so the three labels
/// produced by get_labels3 are trained in a single call.
///
/// @param traindata  one CV_32FC1 sample per row
/// @param labels     one class label per row of `traindata`
/// @param modelpth   output file; an empty string or "*" selects the default
///                   "../data/model_AGE.txt"
/// @throws std::runtime_error if CvSVM::train reports failure
void trainSVM(Mat traindata, Mat labels, string modelpth)
{
    // "*" (the placeholder main() has been passing) and "" keep the old default
    // location so existing callers that load from there keep working.
    const string outpath =
        (modelpth.empty() || modelpth == "*") ? string("../data/model_AGE.txt") : modelpth;

    // Set up the support vector machine parameters.
    CvSVMParams params;
    params.svm_type = SVM::C_SVC;
    params.C = 0.1;
    params.kernel_type = SVM::LINEAR;
    // NOTE(review): only CV_TERMCRIT_ITER is set, so the 1e-6 epsilon is
    // presumably not used as a stopping test; add CV_TERMCRIT_EPS if a
    // tolerance-based stop was intended.
    params.term_crit = TermCriteria(CV_TERMCRIT_ITER, (int)1e7, 1e-6);

    // Train the SVM.
    cout << "Starting training process" << endl;
    CvSVM svm;
    if (!svm.train(traindata, labels, Mat(), Mat(), params))
        throw runtime_error("SVM training failed");
    cout << "Finished training process" << endl;
    svm.save(outpath.c_str());
}
/// Classifies one feature vector with the SVM model stored at `modpth`.
///
/// NOTE(review): the model is re-loaded from disk on every call; cache it if
/// this is called in a loop.
///
/// @param feats   feature vector; only the first 2048 values are used
/// @param modpth  path of a saved CvSVM model
/// @return the predicted class label rounded to the nearest integer, or -1 if
///         `modpth` is null or `feats` holds fewer than 2048 values
int sex_detect(vector<float> &feats, const char* modpth)
{
    const int kFeatDim = 2048;

    // Reading 2048 values from a shorter vector would be out of bounds.
    if (modpth == nullptr || feats.size() < static_cast<size_t>(kFeatDim))
        return -1;

    CvSVM svm;
    svm.load(modpth);

    // Wrap the caller's buffer as a 1 x 2048 row instead of copying it through a
    // raw new[]/delete[] buffer (which leaked if predict() threw).
    Mat test(1, kFeatDim, CV_32FC1, &feats[0]);
    return cvRound(svm.predict(test));
}
int main()
{
//int labels[3]=[class1,class2,class3];
Mat class1 = traindata("../data/feats_left.txt",40);
Mat class2 = traindata("../data/feats_right.txt",36);
Mat class3 = traindata("../data/feats_pos.txt",48);
//Mat traindata = get_traindata(class1, class2);
//Mat labels = get_labels(class1, class2);
Mat traindata = get_traindata3(class1, class2, class3);
Mat labels = get_labels3(class1, class2, class3);
trainSVM(traindata, labels, "*");
CvSVM SVM;
SVM.load("../data/model_AGE.txt");
ifstream ifs;
float testdata[2048];
ifs.open("../data/feats_test.txt");
for (int i = 0; i < 2048; i++)
{
ifs >> testdata[i];
}
Mat test = Mat(1, 2048, CV_32FC1, testdata);
float result = SVM.predict(test);
if (result == 1)
cout << "左偏30度" << endl;
else if (result == 2)
cout<< "右偏30度" <<endl;
else if (result == 3)
cout<< "正脸" <<endl;
ifs.close();
system("pause");
}