去年的时候,使用过 C++ 版本的SVM 实现过基于无人机的道路检测,但是当时,对博大精深的 SVM 只是了解皮毛。最近,对 SVM 的基本公式及相关变体的公式,重新推导了一遍,并且分别用 Python 和 C++ 实现了一遍。此文,是用 C++ 实现的。
SVM 公式的推导是需要掌握的,其实,如果一步一步地推导,基本公式是不难推导的,比如目标函数啊,拉格朗日乘子法, 以及涉及的对偶问题、KKT 条件、SMO 算法等。当进一步往下推导,比如 核函数、软间隔以及正则化等,需要点耐心去理解。
数据集: http://download.csdn.net/download/wz2671/10172405
main.cpp
#include <Windows.h>
#include "SVM.h"
#include "matrix.h"
#include "mat.h"
#include <iostream>
#pragma comment(lib, "libmat.lib")
#pragma comment(lib, "libmx.lib")
using namespace std;
const int fn = 13;
const int sn1 = 59;
const int sn2 = 71;
const int sn3 = 48;
const int sn = 178;
int readData(double* &data, double* &label)
{
MATFile *pmatFile = NULL;
mxArray *pdata = NULL;
mxArray *plabel = NULL;
int ndir;//矩阵数目
//读取数据文件
pmatFile = matOpen("wine_data.mat", "r");
if (pmatFile == NULL) return -1;
/*获取.mat文件中矩阵的名称
char **c = matGetDir(pmatFile, &ndir);
if (c == NULL) return -2;
*/
pdata = matGetVariable(pmatFile, "wine_data");
data = (double *)mxGetData(pdata);
matClose(pmatFile);
//读取类标
pmatFile = matOpen("wine_label.mat", "r");
if (pmatFile == NULL) return -1;
plabel = matGetVariable(pmatFile, "wine_label");
label = (double *)mxGetData(plabel);
matClose(pmatFile);
}
int main()
{
doubl *data ;
double *label;
readData(data, label);
//需要注意从.mat文件中读取出的数据按列存储
double *d;
double *l;
SVM svm;
//第一组数据集与第二组数据集 预处理
l = new double[sn1 + sn2];
for(int i=0; i<sn1+sn2; i++)
{
if (fabs(label[i] - 2)<1e-3) l[i] = -1;
else l[i] = 1;
}
d = new double[(sn1 + sn2)*fn];
for (int i = 0; i < fn; i++)
{
for (int j = 0; j < sn1+sn2; j++)
{
d[j*fn + i] = data[i*sn + j];
}
}
/*
for (int i = 0; i < sn1 + sn2; i++)
{
for (int j = 0; j < fn; j++)
{
cout << d[i*fn + j] << ' ';
}
cout << endl;
}
*/
svm.initialize(d, l, sn1+sn2, fn);
svm.SMO();
cout << "数据集1和数据集2";
svm.show();
delete l;
delete d;
//第二组数据集与第三组数据集
l = new double[sn2 + sn3];
for (int i = sn1; i < sn1 + sn2 + sn3; i++)
{
if (fabs(label[i] - 2) < 1e-3) l[i-sn1] = 1;
else if (fabs(label[i] - 3) < 1e-3) l[i-sn1] = -1;
}
d = new double[(sn2 + sn3)*fn];
for (int i = 0; i < fn; i++)
{
for (int j = sn1; j < sn; j++)
{
d[(j - sn1)*fn + i] = data[i*sn + j];
}
}
svm.initialize(d, l , sn2+sn3, fn);
svm.SMO();
cout << "\n数据集2和数据集3";
svm.show();
delete l;
delete d;
//第一组数据集和第三组数据集
l = new double[sn1 + sn3];
for (int i = 0; i < sn1 + sn2 + sn3