c++版id3决策树实现

上一篇文章实现了Python版的决策树,本文借用上篇文章的算法思路实现了C++版的算法。数据结构是自己设计实现的,肯定有很多不好的地方,希望各位高手能给出些建议,这是我第一次使用C++来实现这样大的程序。程序中用到的数据是借用网上一位哥们儿的。

DataSet.h

#ifndef SAMPLE_H
#define SAMPLE_H
#include<vector>
#include<string>
#include<set>
using namespace std;
// Tabular data set used to train and test the ID3 decision tree: a list of
// attribute (column) names plus one Sample per data row.
// NOTE(review): ReadData takes an ifstream, but this header only includes
// <vector>, <string> and <set>; it compiles only when <fstream> has been
// included before this header (as DataSet.cpp does).
class DataSet
{
private:
	void ReadData(ifstream &in);// reads the header line and every sample row from the stream
	vector<string> SplitLine(const string &str);// splits one line read from the file into tab-separated fields
	double Entropy(const vector<int> &v);// base-2 entropy of a list of counts

public:
	// One data row.
	struct Sample
	{
		vector<string> attributes;// attribute values, in the same column order as attributesNames
		string targetAttributes;// value of the target (last) column
	};

	// Stored data
	vector<string> attributesNames; // attribute (column) names
	vector<Sample> dataSet; // the samples

	DataSet();// default constructor: empty data set
	DataSet(const string &fileName);// builds the data set from a file
	~DataSet();
	void Print();// prints the data set so the stored data can be inspected
	double Gain(const string &featureName);// information gain of the named attribute

	DataSet SplitDataSet(const string&featureName, const string &featureValue);// subset of samples whose featureName column equals featureValue
	string BestSplitFeature();// attribute with the largest information gain
	string CommomTargetValue();// most frequent target value among the samples
	bool IsSameTarget();// true when all samples share exactly one target value
	int GetAttributeIndex(const string &attributeName);// position of the attribute in attributesNames; throws when absent

};

#endif

DataSet.cpp

#include<vector>
#include<iostream>
#include<fstream>
#include<string>
#include<map>
#include<set>
#include<numeric>
#include"DataSet.h"

using namespace std;
// Default constructor: creates an empty data set (no attribute names, no samples).
DataSet::DataSet() {}
// Builds the data set from the text file `fileName`.
// When the file cannot be opened, a message is written to stdout and the
// data set stays empty.
DataSet::DataSet(const string &fileName)
{
	ifstream input(fileName);
	if (input)
	{
		ReadData(input);
		return;
	}
	cout << "文件打开失败";
}
// Destructor. Both members are std::vector objects that release their own
// storage when destroyed, so no manual clearing is required here.
DataSet::~DataSet()
{
}

// Reads the data set from an open stream.
//
// The first line is the header; every following line is one sample. Each line
// is split with SplitLine. The first field (row id) is dropped, the last field
// is the target, and the fields in between are the attributes.
//
// An empty stream or a header with fewer than two fields leaves the data set
// empty. Rows with fewer than two fields (e.g. blank lines) are skipped, since
// they have neither an id nor a target. Rows are otherwise stored as-is and
// are expected to have the same number of columns as the header.
void DataSet::ReadData(ifstream &in)
{
	string line;
	if (!getline(in, line))
	{
		return; // empty file: nothing to read
	}
	// Drop the '\r' left over when a CRLF file is read without translation.
	if (!line.empty() && line[line.size() - 1] == '\r')
	{
		line.erase(line.size() - 1);
	}

	vector<string> fields = SplitLine(line);
	if (fields.size() < 2)
	{
		return; // header needs at least the id column and the target column
	}
	attributesNames.assign(fields.begin() + 1, fields.end() - 1);

	while (getline(in, line))
	{
		if (!line.empty() && line[line.size() - 1] == '\r')
		{
			line.erase(line.size() - 1);
		}
		fields = SplitLine(line);
		if (fields.size() < 2)
		{
			continue; // blank or malformed row; [begin+1, end-1) would be invalid
		}
		Sample s;
		s.attributes.assign(fields.begin() + 1, fields.end() - 1);
		s.targetAttributes = fields.back();
		dataSet.push_back(s);
	}
}
// Splits one line into fields separated by runs of one or more tab characters.
// A leading tab yields an empty first field; trailing tabs do not produce a
// trailing empty field; an empty line yields a single empty field.
vector<string> DataSet::SplitLine(const string &str)
{
	vector<string> fields;
	string::size_type start = 0;
	for (;;)
	{
		const string::size_type tab = str.find('\t', start);
		if (tab == string::npos)
		{
			// No more separators: the rest of the line is the last field.
			fields.push_back(str.substr(start));
			break;
		}
		fields.push_back(str.substr(start, tab - start));
		// Skip the whole run of tabs to reach the next field.
		start = str.find_first_not_of('\t', tab);
		if (start == string::npos)
		{
			break; // line ended with tabs
		}
	}
	return fields;
}

//打印数据集
void DataSet::Print()
{
	if (attributesNames.size() == 0)
	{
		cout << "数据集为空" << endl;
		return;
	}

	for (vector<string>::iterator nIter = attributesNames.begin(); nIter != attributesNames.end();++nIter)
	{
		cout << *nIter<<"\t";
	}
	cout << endl;

	for (vector<Sample>::iterator sIter = dataSet.begin(); sIter != dataSet.end(); ++sIter)
	{
		for (vector<string>::iterator aIter = sIter->attributes.begin(); aIter != sIter->attributes.end(); ++aIter)
		{
			cout << *aIter << "\t";
		}
		cout << sIter->targetAttributes;
		cout << endl;
	}
}

//信息增益计算
double DataSet::Gain(const string &featureName)
{
	vector<string>::iterator findIter = find(attributesNames.begin(), attributesNames.end(), featureName);
	if (findIter == attributesNames.end()) throw "参数错误";
	
	vector<string>::size_type index;//数据集的列标签
	for (index = 0; index != attributesNames.size(); ++index)
	{
		if (attributesNames[index] == featureName){
			break;
		}
	}

	map<string, int> targetMap;//统计各个属性的样本数
	map<string,map<string, int> > featureMap;//键为
	for (vector<Sample>::iterator sIter = dataSet.begin(); sIter != dataSet.end(); ++sIter)
	{
		++targetMap[sIter->targetAttributes];
		++featureMap[sIter->attributes[index]][sIter->targetAttributes];
	}

	vector<int> targetCount;//统计目标变量
	
	for (map<string, int>::iterator tarIter = targetMap.begin(); tarIter != targetMap.end(); ++tarIter)
	{
		targetCount.push_back(tarIter->second);
	}
	double gain = Entropy(targetCount);//总的熵
	
	int sTotal = dataSet.size();//总的样本数
	for (map<string, map<string, int> >::iterator featIter = featureMap.begin(); featIter != featureMap.end(); ++featIter)
	{
		vector<int> featureCout;
		for (map<string, int>::iterator featIter2 = featIter->second.begin(); featIter2 != featIter->second.end(); ++featIter2)
		{
			featureCout.push_back(featIter2->second);
		}
		int s = accumulate(featureCout.begin(), featureCout.end(),0);//特征出现的总次数
		
		gain -= 1.0 * s / sTotal*Entropy(featureCout);

	}
	return gain;
}

// Base-2 Shannon entropy of a distribution given as occurrence counts.
// Returns 0 when the counts sum to zero. Zero (or negative) counts contribute
// nothing, following the convention 0 * log2(0) = 0; evaluating the expression
// directly would produce NaN.
double DataSet::Entropy(const vector<int> &v)
{
	const int totalNum = accumulate(v.begin(), v.end(), 0);
	if (totalNum <= 0)
	{
		return 0.0; // avoids dividing by zero
	}

	double entropy = 0.0;
	for (vector<int>::size_type i = 0; i != v.size(); ++i)
	{
		if (v[i] <= 0)
		{
			continue;
		}
		const double p = static_cast<double>(v[i]) / totalNum;
		entropy -= p * log2(p);
	}
	return entropy;
}

// Returns the subset of samples whose attribute `featureName` equals
// `featureValue`, with that attribute's column removed from both the
// attribute names and every returned sample.
// Throws a C string ("参数错误") when the attribute name is unknown.
DataSet DataSet::SplitDataSet(const string &featureName, const string &featureValue)
{
	// Column of the attribute being split on.
	vector<string>::size_type index = 0;
	while (index != attributesNames.size() && attributesNames[index] != featureName)
	{
		++index;
	}
	if (index == attributesNames.size()) throw "参数错误";

	DataSet children; // the resulting subset
	children.attributesNames = attributesNames;
	children.attributesNames.erase(children.attributesNames.begin() + index);

	for (vector<Sample>::const_iterator dIter = dataSet.begin(); dIter != dataSet.end(); ++dIter)
	{
		// Rows too short to have this column cannot match.
		if (index >= dIter->attributes.size() || dIter->attributes[index] != featureValue)
		{
			continue;
		}
		Sample reduced = *dIter;
		// Remove the used column by position. Searching for the first cell equal
		// to featureValue would remove the wrong column whenever an earlier
		// attribute happens to hold the same value string.
		reduced.attributes.erase(reduced.attributes.begin() + index);
		children.dataSet.push_back(reduced);
	}
	return children;
}

// Returns the name of the attribute with the largest information gain.
// Ties keep the earliest attribute. When no attribute has a positive gain,
// the first best candidate is still returned, so the result is a valid
// attribute name whenever attributesNames is non-empty. Returns an empty
// string only when there are no attributes.
string DataSet::BestSplitFeature()
{
	string bestSplit;        // attribute with the largest gain so far
	double maxGain = 0.0;    // its gain
	bool hasCandidate = false;
	for (vector<string>::iterator tIter = attributesNames.begin(); tIter != attributesNames.end(); ++tIter)
	{
		const double gain = Gain(*tIter); // computed once per attribute
		if (!hasCandidate || gain > maxGain)
		{
			maxGain = gain;
			bestSplit = *tIter;
			hasCandidate = true;
		}
	}
	return bestSplit;
}

//判断目标属性是否唯一
bool DataSet::IsSameTarget()
{
	set<string> targetContine;
	vector<Sample>::iterator sIter = dataSet.begin();
	for (; sIter != dataSet.end(); ++sIter)
	{
		targetContine.insert(sIter->targetAttributes);
	}
	if (targetContine.size() == 1)
	{
		return true;
	}
	else
	{
		return false;
	}
}

// Returns the most frequent target value among the samples. On a tie the
// value that sorts first wins (std::map iteration order with a strict
// comparison). Returns an empty string for an empty data set.
string DataSet::CommomTargetValue()
{
	map<string, int> tally;
	for (const Sample &row : dataSet)
	{
		tally[row.targetAttributes] += 1;
	}

	string winner;
	int winnerCount = 0;
	for (const auto &entry : tally)
	{
		if (entry.second > winnerCount)
		{
			winnerCount = entry.second;
			winner = entry.first;
		}
	}
	return winner;
}

// Returns the position of `attributeName` in attributesNames.
// Throws a C string ("属性不存在") when the attribute is not present.
int DataSet::GetAttributeIndex(const string &attributeName)
{
	for (vector<string>::size_type index = 0; index != attributesNames.size(); ++index)
	{
		if (attributesNames[index] == attributeName)
		{
			return static_cast<int>(index);
		}
	}
	// Unconditional throw: every path either returns or throws, so the
	// compiler no longer sees control reaching the end of a non-void function.
	throw "属性不存在";
}

DesctionTree.h

#ifndef DESCITION_H
#define DESCITION_H
#include<string>
#include<vector>
#include<map>
#include"DataSet.h"
using  std::string;
using std::vector;
// ID3 decision tree built from a DataSet.
// NOTE(review): `map` is used unqualified below although only std::string and
// std::vector are brought in by using-declarations in this header; it
// resolves because DataSet.h contains `using namespace std;`.
class DesctionTree
{
private:
	// Tree node. CreateNode stores the splitting attribute's name in `value`
	// for an internal node and a target value for a leaf (no children).
	struct Node
	{
		string value;
		map<string, Node* > children; // attribute value -> subtree for that value
	};

	Node *root; // root of the tree; NULL until CreateTree is called

	Node* CreateNode(DataSet &trainSet); // recursively builds the subtree for trainSet
	void ShowNode(Node *rNode,int level); // prints the subtree; `level` is the tab indent for its branches
	

	string ClassOneSample(Node* rNode,vector<string> &v,vector<string> &attributeNames); // classifies one sample's attribute values
public:
	DesctionTree();
	// Destructor is disabled; nodes allocated by CreateNode are never freed in the code shown.
	//~DesctionTree();
	void CreateTree(DataSet &trainSet); // builds the tree from the training set
	vector<string> ClassTest(DataSet &examples); // classifies every sample in `examples`
	void ShowTree(); // prints the tree to stdout
	void ClearTree(Node* rNode); // declared only: its definition in Desction.cpp is commented out
	

};


#endif

Desction.cpp

#include<vector>
#include<string>
#include<iostream>
#include"DesctionTree.h"

using namespace std;

// Creates an empty tree (no root node).
DesctionTree::DesctionTree() : root(NULL) {}


// Builds the decision tree from `trainSet` and stores its root.
// CreateNode allocates the root itself, so no node is allocated here;
// allocating one first and then overwriting the pointer would leak it.
// NOTE(review): a tree built by an earlier call is not freed before `root`
// is overwritten.
void DesctionTree::CreateTree(DataSet &trainSet)
{
	root = CreateNode(trainSet);
}

// Prints the whole tree to stdout; does nothing when no tree has been built.
void DesctionTree::ShowTree()
{
	if (root == NULL)
	{
		return;
	}
	ShowNode(root, 1);
}

// Recursively builds the subtree for `trainSet` (ID3).
//
// Returns NULL (after printing a notice) when the set has no samples.
// Produces a leaf holding a target value when no attributes remain, when all
// samples share one target, or when no usable split attribute is reported.
// Otherwise produces an internal node named after the best split attribute,
// with one child per distinct value of that attribute.
// The caller owns the returned nodes.
DesctionTree::Node* DesctionTree::CreateNode(DataSet &trainSet)
{
	if (trainSet.dataSet.empty())
	{
		cout << "数据为空" << endl;
		return NULL;
	}

	Node* rootNode = new Node;

	// No attributes left to split on: majority vote.
	if (trainSet.attributesNames.empty())
	{
		rootNode->value = trainSet.CommomTargetValue();
		return rootNode;
	}

	// Pure node: every sample has the same target.
	if (trainSet.IsSameTarget())
	{
		rootNode->value = trainSet.dataSet.front().targetAttributes;
		return rootNode;
	}

	const string bestSplitAttributes = trainSet.BestSplitFeature();
	vector<string>::size_type index = 0;
	while (index != trainSet.attributesNames.size()
		&& trainSet.attributesNames[index] != bestSplitAttributes)
	{
		++index;
	}
	if (index == trainSet.attributesNames.size())
	{
		// BestSplitFeature returned a name that is not an attribute (e.g. an
		// empty string when no attribute has positive gain). Indexing the
		// samples with `index` would be out of bounds, so fall back to a
		// majority-vote leaf.
		rootNode->value = trainSet.CommomTargetValue();
		return rootNode;
	}
	rootNode->value = bestSplitAttributes;

	// Distinct values the chosen attribute takes in this set.
	set<string> valueSet;
	for (vector<DataSet::Sample>::iterator sIter = trainSet.dataSet.begin();
		sIter != trainSet.dataSet.end(); ++sIter)
	{
		valueSet.insert(sIter->attributes[index]);
	}

	// One subtree per value, built from the matching subset.
	for (set<string>::iterator setIter = valueSet.begin(); setIter != valueSet.end(); ++setIter)
	{
		DataSet childSet = trainSet.SplitDataSet(bestSplitAttributes, *setIter);
		rootNode->children[*setIter] = CreateNode(childSet);
	}
	return rootNode;
}

// Prints the subtree rooted at `rNode`: the node's value on the current line,
// then each branch on its own line as `level` tabs, the branch's attribute
// value, "->", and the child subtree printed one level deeper.
void DesctionTree::ShowNode(Node *rNode,int level)
{
	cout << rNode->value << '\n';
	for (const auto &branch : rNode->children)
	{
		for (int remaining = level; remaining > 0; --remaining)
		{
			cout << '\t';
		}
		cout << branch.first << "->";
		ShowNode(branch.second, level + 1);
	}
}

// Classifies every sample in `examples` with the current tree and returns the
// predicted values in sample order.
vector<string> DesctionTree::ClassTest(DataSet &examples)
{
	vector<string> predictions;
	for (DataSet::Sample &sample : examples.dataSet)
	{
		predictions.push_back(ClassOneSample(root, sample.attributes, examples.attributesNames));
	}
	return predictions;
}


// Classifies one sample by walking the tree from `rNode`.
//
// `v` holds the sample's attribute values and `attributeNames` the matching
// column names (same order). Returns the leaf's value, or " " when the sample
// cannot be classified: null node, the node's attribute is not among
// `attributeNames`, the sample has no value for it, or the tree has no branch
// for the sample's value. Without that last case the function would run off
// its end with no return value.
string DesctionTree::ClassOneSample(Node* rNode,vector<string> &v, vector<string> &attributeNames)
{
	if (rNode == NULL)
	{
		return " ";
	}
	if (rNode->children.empty())
	{
		return rNode->value; // leaf: holds the predicted target value
	}

	// Column of the attribute this node splits on.
	vector<string>::size_type index = 0;
	while (index != attributeNames.size() && attributeNames[index] != rNode->value)
	{
		++index;
	}
	if (index == attributeNames.size() || index >= v.size())
	{
		return " ";
	}

	map<string, Node*>::iterator branch = rNode->children.find(v[index]);
	if (branch == rNode->children.end())
	{
		return " "; // attribute value not seen on this path during training
	}

	// Drop the used column from both vectors so they stay aligned with the
	// attributes that remain below this node.
	vector<string> tempV(v);
	vector<string> tempAttributeNames(attributeNames);
	tempV.erase(tempV.begin() + index);
	tempAttributeNames.erase(tempAttributeNames.begin() + index);

	return ClassOneSample(branch->second, tempV, tempAttributeNames);
}

/*
void DesctionTree::ClearTree(Node *rNode)
{
	if (rNode->children.empty())
	{
		delete rNode;
		rNode = NULL;
	}

	for (map<string, Node*>::iterator mIter = rNode->children.begin(); mIter != rNode->children.end(); ++mIter)
	{
		ClearTree(mIter->second);
	}
}
*/

main.cpp

/**********************************
** id3决策树
** @author:郑午
** @time:2014-06-19
**
*********************************/

#include<iostream>
#include<vector>
#include<set>
#include<iterator>
#include"DataSet.h"
#include"DesctionTree.h"
using namespace std;

int main()
{
	//训练
	DataSet dataMat("data.txt");
	DesctionTree id3Tree;
	id3Tree.CreateTree(dataMat);
	cout << "树结构:"<<endl;
	id3Tree.ShowTree();

	//测试
	DataSet testSet("test.txt");
	vector<string> result=id3Tree.ClassTest(testSet);
	cout << "测试结果:" << endl<<endl;
	copy(result.begin(), result.end(), ostream_iterator<string>(cout, " "));
	cout << endl<<endl;
	system("pause");
}



数据文件

Day Outlook Temperate Humidity Wind PlayTennis
D1 Sunny Hot High Weak No
D2 Sunny Hot High Strong No
D3 Overcast Hot High Weak Yes
D4 Rain Mild High Weak Yes
D5 Rain Cool Normal Weak Yes
D6 Rain Cool Normal Strong No
D7 Overcast Cool Normal Strong Yes
D8 Sunny Mild High Weak No
D9 Sunny Cool Normal Weak Yes
D10 Rain Mild Normal Weak Yes
D11 Sunny Mild Normal Strong Yes
D12 Overcast Mild High Strong Yes
D13 Overcast Hot Normal Weak Yes
D14 Rain Mild High Strong No


测试数据是训练数据的一个子集。

运行结果:


评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值