一种常见的数据挖掘算法——SPRINT 算法的简单实现

#include <iostream>
#include <fstream>
#include <string>
#include <vector>
#include <math.h>
#include <map>
#include <set>
#include <list>
#include <algorithm>
using namespace std;
// Split thresholds keyed by attribute index (attrelem::attridx).
// getmingini() stores a threshold here each time it finds a lower-gini
// candidate; splitattrlist() and buildtree() read it back by attribute index.
map<int,float> xmap;
// One entry of an attribute list: the value of a single attribute for a single
// record, together with that record's class label and record id.
struct attrelem
{
	int attridx;              // index of the attribute (column) this entry belongs to
	float attrval;            // numeric value of that attribute for the record
	std::string classlabel;   // class label of the record
	int    rid;               // record id; main() fills it with the row number

	// All numeric members are initialised so that a default-constructed entry
	// never carries an indeterminate attrval.
	attrelem():attridx(0),attrval(0.0f),rid(0){}
};
// Two attribute-list entries compare equal when their value, class label and
// record id all match. attridx does not take part in the comparison.
bool operator==(const attrelem &lhs,const attrelem &rhs)
{
	if (!(lhs.attrval == rhs.attrval))
		return false;
	if (!(lhs.classlabel == rhs.classlabel))
		return false;
	return lhs.rid == rhs.rid;
}

// Node of the binary decision tree. node_ptr is a pointer to such a node.
typedef struct treenode
{
	std::string category;     // class label; non-empty only on leaf nodes
	std::vector< std::vector<attrelem> > vecattrlist;  // attribute lists copied in by buildtree() before splitting
	float x;                  // split threshold: gettype() sends values <= x to the left child
	int splittattridx;        // index of the attribute tested at this node, -1 if none
	treenode *left;           // subtree for values <= x
	treenode *right;          // subtree for values > x

	// Initialisers follow declaration order, and x is given a defined value so
	// that a freshly created node never holds an indeterminate threshold.
	treenode():x(0.0f),splittattridx(-1),left(NULL),right(NULL){}

}*node_ptr;
// Returns the midpoint between x1 and x2.
float getx(float x1,float x2)
{
	// The sum is halved in double precision (2.0 is a double literal) and
	// narrowed back to float for the result.
	const double midpoint = (x1 + x2) / 2.0;
	return static_cast<float>(midpoint);
}
// Removes every ' ' character from str, in place.
//
// Uses the erase/remove idiom. The previous hand-written loop advanced the
// iterator after erase(), which skipped the character following each removed
// space (so runs of spaces were only partly removed) and stepped past end()
// when the last character was a space.
void trimspace(std::string & str)
{
	str.erase(std::remove(str.begin(), str.end(), ' '), str.end());
}
// Splits str at every occurrence of splitter and appends the pieces to vecout.
// Empty pieces are kept, so an input with N separators always appends N + 1
// strings (an empty input appends one empty string).
void getsplitstring(std::string str,char splitter,std::vector<std::string> & vecout)
{
	std::string::size_type start = 0;
	for (;;)
	{
		const std::string::size_type cut = str.find(splitter, start);
		if (cut == std::string::npos)
			break;
		vecout.push_back(std::string(str, start, cut - start));
		start = cut + 1;
	}

	// Whatever follows the last separator (possibly nothing) is the final piece.
	vecout.push_back(str.substr(start));
}
// Returns x multiplied by itself.
float square(float x)
{
	const float squared = x * x;
	return squared;
}
// Returns the index of the attribute list that contains the element ret points
// at (compared by address), or -1 when no list holds that element.
int getisplitter(attrelem *ret,std::vector< std::vector<attrelem> > &vecattrlist)
{
	for (std::size_t listidx = 0; listidx < vecattrlist.size(); ++listidx)
	{
		std::vector<attrelem> &column = vecattrlist[listidx];
		for (std::size_t k = 0; k < column.size(); ++k)
		{
			if (&column[k] == ret)
				return static_cast<int>(listidx);
		}
	}
	return -1;
}
float getgini(float x,vector<attrelem> & attrlist)
{
	
	vector<string> lessequalval;
	vector<string>  greaterlval;
	for (int i = 0;i<attrlist.size();i++)
	{
		if (attrlist[i].attrval<=x)
		{
			lessequalval.push_back(attrlist[i].classlabel);

		}
		else
		{
			greaterlval.push_back(attrlist[i].classlabel);
		}
	}

	map<string,int> catecnt;
	for (int i=0;i<lessequalval.size();i++)
	{
		map<string,int>::iterator it = catecnt.find(lessequalval[i]);
		if (it == catecnt.end())
		{
			catecnt.insert(make_pair(lessequalval[i],0));
			catecnt[lessequalval[i]]++;
		}
		else
		{
			catecnt[lessequalval[i]]++;
		}
	}

	float lessequalgini = 0.0;
	for (map<string,int>::iterator it = catecnt.begin();
		it != catecnt.end();it++)
	{
		lessequalgini = lessequalgini + square((float)it->second/(float)lessequalval.size());
	}
	lessequalgini = 1 - lessequalgini;

	catecnt.clear();
	for (int i=0;i<greaterlval.size();i++)
	{
		map<string,int>::iterator it = catecnt.find(greaterlval[i]);
		if (it == catecnt.end())
		{
			catecnt.insert(make_pair(greaterlval[i],0));
			catecnt[greaterlval[i]]++;
		}
		else
		{
			catecnt[greaterlval[i]]++;
		}
	}

	float greatergini = 0.0;
	for (map<string,int>::iterator it = catecnt.begin();
		it != catecnt.end();it++)
	{
		greatergini = greatergini + square((float)it->second/(float)greaterlval.size());
	}
	greatergini = 1 - greatergini;

	float gini = ((float)lessequalval.size()/(float)attrlist.size()) * lessequalgini + ((float)greaterlval.size()/(float)attrlist.size()) * greatergini;
	return gini;	

}

int getmingini(vector< vector<attrelem> > &vecattrlist)
{
	

	int isplitter = -1;
	float mingini = getgini(getx(vecattrlist[0][0].attrval,vecattrlist[0][1].attrval),vecattrlist[0]);
	isplitter = 0;
	
	for (int i = 0;i<vecattrlist.size();i++)
	{
		
		for (int j=0;j<vecattrlist[i].size() -1 ;j++)
		{
			if (vecattrlist[i][j].attrval != vecattrlist[i][j+1].attrval)
			{
				float x = getx(vecattrlist[i][j].attrval,vecattrlist[i][j+1].attrval);
				float gini = getgini(x,vecattrlist[i]);
				if (mingini > gini)
				{
					mingini = gini;
					isplitter = i;
					xmap[vecattrlist[i][j].attridx] = x;
				}
			}
		}
		
	}
	return isplitter;


}
bool isnodepure(vector<attrelem> & attrlist)
{

	bool pure = true;
	string classlabel = attrlist[0].classlabel;
	for (int i = 0;i<attrlist.size();i++)
	{
		if (classlabel != attrlist[i].classlabel)
		{
			pure = false;
			break;
		}

	}
	return pure;
}


// Returns an iterator to the attribute list at position isplitter.
//
// When isplitter is negative or past the last list, vecattrlist.end() is
// returned. The previous version reached the end of the function without a
// return statement in that case, which is undefined behaviour. Callers that
// may pass an out-of-range index must compare the result with end() before
// dereferencing or erasing it.
std::vector< std::vector<attrelem> >::iterator getit(int isplitter, std::vector< std::vector<attrelem> > &vecattrlist)
{
	if (isplitter < 0 || static_cast<std::size_t>(isplitter) >= vecattrlist.size())
		return vecattrlist.end();

	return vecattrlist.begin() + isplitter;
}
int splitattrlist(vector< vector<attrelem> > &vecattrlist,vector< vector<attrelem> > &left,vector< vector<attrelem> > &right)
{

	
	int isplitter = getmingini(vecattrlist);
	int ret = vecattrlist[isplitter][0].attridx;

	while(left.empty() || right.empty() || left[0].empty() || right[0].empty())
	{

		if (!left.empty() && !right.empty() && (!left[0].empty() || !right[0].empty()) )
		{
			vecattrlist.erase(getit(isplitter,vecattrlist));

			isplitter = getmingini(vecattrlist);			
		}
		left.clear();
		right.clear();
		for (int i = 0;i<vecattrlist.size();i++)
		{
			vector<attrelem> attrleft, attrright;
			if (i == isplitter)
				continue;

			float x = xmap[vecattrlist[isplitter][0].attridx];
			for (int j = 0;j<vecattrlist[i].size();j++)
			{
				if (vecattrlist[isplitter][j].attrval <= x /*&& vecattrlist[isplitter][j].rid == vecattrlist[i][j].rid*/)
				{
					attrleft.push_back(vecattrlist[i][j]);
				}
				else if (vecattrlist[isplitter][j].attrval > x /*&& vecattrlist[isplitter][j].rid == vecattrlist[i][j].rid*/)
					attrright.push_back(vecattrlist[i][j]);
			}

			left.push_back(attrleft);
			right.push_back(attrright);

		}
	}	
	return ret;

}


void buildtree(vector< vector<attrelem> > &vecattrlist,node_ptr *tree)
{
	if (vecattrlist.size() == 0)
		return;
	if (isnodepure(vecattrlist[0]))
	{
		(*tree) = new treenode();
		(*tree)->category = vecattrlist[0][0].classlabel;
		(*tree)->left = NULL;
		(*tree)->right = NULL;
		return;

	}


	vector< vector<attrelem> >left,right;
	vector< vector<attrelem> > vecret;
	
	(*tree) = new treenode();
	(*tree)->vecattrlist = vecattrlist;
	(*tree)->splittattridx = splitattrlist(vecattrlist,left,right);
	(*tree)->x = xmap[(*tree)->splittattridx];

	


	buildtree(left,&(*tree)->left);

	buildtree(right,&(*tree)->right);


}
// Returns true when vallist contains an element equal to val.
bool isvalinlist(float val,std::vector<float> & vallist)
{
	return std::find(vallist.begin(), vallist.end(), val) != vallist.end();
}
// Classifies one record by walking the decision tree.
//
// data holds the record's fields as strings; at every inner node the field at
// the node's splittattridx is converted with atof() and compared with the
// node's threshold x (<= x goes left, otherwise right).
//
// Returns the category of the leaf that is reached, or an empty string when
// the walk ends on a NULL subtree or when the record has no field at the
// index a node asks for (previously that read data out of range).
//
// The walk is a loop rather than a recursion so the by-value data vector is
// not copied again at every level of the tree.
std::string gettype(std::vector<std::string> data,node_ptr tree)
{
	while (tree)
	{
		if (!tree->category.empty())
			return tree->category;

		const int idx = tree->splittattridx;
		if (idx < 0 || static_cast<std::size_t>(idx) >= data.size())
			return "";

		const float val = atof(data[idx].c_str());
		tree = (val <= tree->x) ? tree->left : tree->right;
	}
	return "";
}

// Prints the tree rooted at root to standard output, one node per line,
// indented by one space per level. A leaf is printed as "Class:<label>", an
// inner node as "<attribute number, 1-based>:<threshold>" followed by its left
// and then its right subtree.
void outputtree(treenode *root,int level)
{
	if (root == NULL)
		return;

	for (int pad = level; pad > 0; --pad)
		std::cout<<" ";

	if (!root->category.empty())
	{
		std::cout<<"Class:"<<root->category<<std::endl;
		return;
	}

	std::cout<<root->splittattridx+1<<":<"<<root->x<<">"<<std::endl;
	outputtree(root->left,level+1);
	outputtree(root->right,level+1);
}
// Ordering predicate: true when left's attribute value is smaller than right's.
bool lessthan(const attrelem & left,const attrelem & right)
{
	return left.attrval < right.attrval;
}
// Entry point: trains a decision tree on argv[2] and prints one predicted
// class label per record of argv[1].
//
// Both files are comma-separated text. Training rows hold 16 numeric
// attribute values followed by the class label in column 17; test rows need at
// least the 16 attribute values.
//
// Returns 0 on success and 1 on any error. (The previous version returned
// `true`/`false`, i.e. exit status 1 on success and 0 on failure.)
int main(int argc, char* argv[])
{
	const std::size_t kNumAttrs = 16;   // number of attribute columns
	const std::size_t kLabelCol = 16;   // column of the class label in training rows

	if (argc < 3)
	{
		std::cout<<"Usage: "<<argv[0]<<" [filename].test [filename].train"<<std::endl;
		return 1;
	}

	std::ifstream input(argv[2],std::ios_base::in);
	if (!input)
	{
		std::cout<<"Error: open file failed."<<std::endl;
		return 1;
	}

	// std::getline has no fixed buffer, so long lines no longer put the stream
	// into a failed state and silently end the read loop.
	std::vector< std::vector<std::string> > vectrain;
	std::string line;
	while (std::getline(input, line))
	{
		trimspace(line);
		if (line.empty())
			continue;
		std::vector<std::string> vecout;
		getsplitstring(line,',',vecout);
		// Rows without a label column cannot be used for training and would
		// be indexed out of range below.
		if (vecout.size() <= kLabelCol)
			continue;
		vectrain.push_back(vecout);
	}

	if (vectrain.empty())
	{
		std::cout<<"Error: no training data."<<std::endl;
		return 1;
	}

	// Build one attribute list per column. The lists are deliberately left
	// unsorted: splitattrlist() pairs up entries of different lists by
	// position, so every list must keep the records in the same order.
	std::vector< std::vector<attrelem> > vecattrlist;
	for (std::size_t j = 0; j < kNumAttrs; ++j)
	{
		std::vector<attrelem> attrlist;
		for (std::size_t i = 0; i < vectrain.size(); ++i)
		{
			attrelem elem;
			elem.attrval = atof(vectrain[i][j].c_str());
			elem.classlabel = vectrain[i][kLabelCol];
			elem.rid = static_cast<int>(i);
			elem.attridx = static_cast<int>(j);
			attrlist.push_back(elem);
		}
		vecattrlist.push_back(attrlist);
	}

	node_ptr root = NULL;
	buildtree(vecattrlist,&root);

	//outputtree(root,0);

	std::ifstream zooinput(argv[1]);
	if (!zooinput)
	{
		std::cout<<"Error: open file failed."<<std::endl;
		return 1;
	}

	while (std::getline(zooinput, line))
	{
		trimspace(line);
		if (line.empty())
			continue;
		std::vector<std::string> vecout;
		getsplitstring(line,',',vecout);

		// A row with missing attribute columns cannot be classified; print an
		// empty label so the output still has one line per input record.
		if (vecout.size() < kNumAttrs)
		{
			std::cout<<std::endl;
			continue;
		}

		std::string type = gettype(vecout,root);
		std::cout<<type<<std::endl;
	}
	return 0;
}

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值