简述
通过 DTree(决策树)对带标签的数据进行训练、保存、读取和预测,可用于图像特征、文本数据、数据挖掘等场景下的分类任务。
样本
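序号 | 年纪 | 薪水 | 有房 | 有车 | 信贷 | 能否通过 |
1 | Y | L | N | N | F | N |
2 | Y | L | Y | N | G | N |
3 | Y | M | Y | N | G | Y |
4 | Y | M | Y | Y | G | Y |
5 | Y | H | Y | Y | G | Y |
6 | Y | M | N | Y | G | N |
7 | M | L | Y | Y | E | Y |
8 | M | H | Y | Y | G | Y |
9 | M | L | N | Y | G | N |
10 | M | M | Y | Y | F | N |
11 | M | H | Y | Y | E | Y |
12 | M | M | N | N | G | N |
13 | O | L | N | N | G | N |
14 | O | L | Y | Y | E | Y |
15 | O | L | Y | N | E | Y |
16 | O | M | N | Y | G | N |
17 | O | L | N | N | E | N |
18 | O | H | N | Y | F | N |
19 | O | H | Y | Y | E | Y |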
代码
/*
 * Human-readable description of each input attribute, in column order.
 * Each entry has the form "Name (legend)"; the part before the '(' is the
 * attribute name. The list is terminated by a null pointer.
 */
static const char* var_desc[] = {
    "Age (young = Y, middle = M, old = O)",               /* column 0 */
    "Salary (low = L, medium = M, high = H)",             /* column 1 */
    "Own_House (false = N, true = Y)",                    /* column 2 */
    "Own_Car (false = N, true = Y)",                      /* column 3 */
    "Credit_Rating (fair = F, good = G, excellent = E)",  /* column 4 */
    0                                                     /* end-of-list sentinel */
};
/**
 * Trains a decision tree on a 19-sample, 5-attribute labelled data set,
 * prints the importance of every attribute, writes the model to "dtree.xml",
 * reads it back, and classifies one new sample.
 *
 * Attribute values and labels are stored as the float value of their
 * character code (for example 'Y' or 'N'); every variable is declared
 * categorical.
 *
 * @return 0 on success, -1 if training fails.
 */
int DTree_example()
{
    /*
    Headers and using-directives this example relies on (place at file scope):
    #include <opencv2/core/core.hpp>
    #include <opencv2/highgui/highgui.hpp>
    #include <opencv2/imgproc/imgproc.hpp>
    #include <opencv2/ml/ml.hpp>
    #include <cstdio>
    #include <cstring>
    #include <iostream>
    using namespace cv;
    using namespace std;
    */

    /* Training samples: one row per sample, columns in var_desc order. */
    float trainingData[19][5] = {
        { 'Y', 'L', 'N', 'N', 'F' },
        { 'Y', 'L', 'Y', 'N', 'G' },
        { 'Y', 'M', 'Y', 'N', 'G' },
        { 'Y', 'M', 'Y', 'Y', 'G' },
        { 'Y', 'H', 'Y', 'Y', 'G' },
        { 'Y', 'M', 'N', 'Y', 'G' },
        { 'M', 'L', 'Y', 'Y', 'E' },
        { 'M', 'H', 'Y', 'Y', 'G' },
        { 'M', 'L', 'N', 'Y', 'G' },
        { 'M', 'M', 'Y', 'Y', 'F' },
        { 'M', 'H', 'Y', 'Y', 'E' },
        { 'M', 'M', 'N', 'N', 'G' },
        { 'O', 'L', 'N', 'N', 'G' },
        { 'O', 'L', 'Y', 'Y', 'E' },
        { 'O', 'L', 'Y', 'N', 'E' },
        { 'O', 'M', 'N', 'Y', 'G' },
        { 'O', 'L', 'N', 'N', 'E' },
        { 'O', 'H', 'N', 'Y', 'F' },
        { 'O', 'H', 'Y', 'Y', 'E' }
    };
    /* The Mat wraps the array without copying, so the array must outlive it. */
    Mat trainingDataMat(19, 5, CV_32FC1, trainingData);

    /* Labels, one per training row: 'Y' = passes, 'N' = does not pass. */
    float responses[19] = {
        'N', 'N', 'Y', 'Y', 'Y',
        'N', 'Y', 'Y', 'N', 'N',
        'Y', 'N', 'N', 'Y', 'Y',
        'N', 'N', 'N', 'Y'
    };
    Mat responseMat(19, 1, CV_32FC1, responses);

    /* Decision tree parameters. One prior per class; there are two classes. */
    float priors[2] = { 1, 1 };
    CvDTreeParams params(
        15,     /* max_depth */
        1,      /* min_sample_count */
        0,      /* regression_accuracy */
        false,  /* use_surrogates */
        25,     /* max_categories */
        0,      /* cv_folds */
        false,  /* use_1se_rule */
        false,  /* truncate_pruned_tree */
        priors  /* priors */
    );

    /* One type flag per variable: 5 attributes + 1 response, all set to 1
       (categorical). */
    Mat varTypeMat(6, 1, CV_8U, Scalar::all(1));

    /* Train on the samples and build the tree. The model lives on the stack
       so it is released on every return path. */
    CvDTree dtree;
    const bool trained = dtree.train(
        trainingDataMat,
        CV_ROW_SAMPLE,
        responseMat,
        Mat(),
        Mat(),
        varTypeMat,
        Mat(),
        params
    );
    if (!trained)
    {
        fprintf(stderr, "DTree_example: training failed\n");
        return -1;
    }

    /* Compute the importance of every attribute and print it as a percentage. */
    const CvMat* var_importance = dtree.get_var_importance();
    if (var_importance != 0)
    {
        const int count = var_importance->cols * var_importance->rows;
        /* Also stop at the var_desc sentinel so the table is never overrun. */
        for (int i = 0; i < count && var_desc[i] != 0; i++)
        {
            const double val = var_importance->data.db[i];
            const char* desc = var_desc[i];
            const char* paren = strchr(desc, '(');
            /* Attribute name = text before " (": drop the '(' and the space
               in front of it. With no '(' present, print the whole string. */
            int len = paren ? static_cast<int>(paren - desc) - 1
                            : static_cast<int>(strlen(desc));
            if (len < 0)
            {
                len = 0;
            }
            /* "%.*s" prints exactly len characters, so no scratch buffer
               (and no overflow risk) is needed. */
            printf("%.*s:%g%%\n", len, desc, val * 100);
        }
    }

    /* Save the model. */
    dtree.save("dtree.xml");

    /* Load the model. A freshly constructed CvDTree could be used here
       instead; this example reuses the same object. */
    dtree.load("dtree.xml");

    /* Predict the label of a new sample. */
    float myData[5] = { 'M', 'H', 'Y', 'N', 'F' };
    Mat myDataMat(5, 1, CV_32FC1, myData);
    const double r = dtree.predict(myDataMat, Mat(), false)->value;
    /* The predicted value is the character code of the label. */
    std::cout << std::endl << "result:" << static_cast<char>(r) << std::endl;
    return 0;
}