// STL-KMeans-Algorithm

/************************************************************************/
/*                        @Author	Amiber                              */
/*						  @Date	2012-11-25								*/
/*						  @brief: KMeans-Algorithm						*/
/************************************************************************/
#pragma warning(disable:4786)

#include<iostream>  //for cout,...
#include<algorithm> //for algorithms
#include<string>	// for
#include<set> //for set<int,less<int> >
#include<vector> //for vector<>
#include<cmath> //for sqrt,pow
#include<iterator> //for set<int,less<int> >::iterator
#include<fstream> //for ifstream
#include<sstream> // for istringstream
using namespace std;

/************************************************************************/
/*                         @Class Vector                                */
/*						   @brief the representation of point-vector    */
/************************************************************************/
class Vector
{
public:

	//create an empty point that is not yet assigned to any cluster (-1)
	Vector()
		: clusterID(-1)
	{
	}

	//mutable access to the coordinate at index (no bounds check)
	double& operator [] (size_t index)
	{
		return arr[index];
	}

	//read-only access to the coordinate at index (no bounds check)
	double operator [] (size_t index) const 
	{
		return arr[index];
	}

	//overwrite the coordinate at index (no bounds check)
	void setValue(int index,double value)
	{
		arr[index] = value;
	}

	//append one coordinate, increasing the dimension by one
	void push_back(double value)
	{
		arr.push_back(value);
	}

	//number of coordinates (the dimension of the point)
	int size() const
	{
		return static_cast<int>(arr.size());
	}

	//Euclidean distance between pt and this point.
	//Only the coordinates present in BOTH vectors are compared, so a
	//dimension mismatch (e.g. an empty point parsed from a blank input
	//line) can no longer read past the end of arr.
	double dist(const Vector& pt) const
	{
		const size_t n = min(arr.size(), static_cast<size_t>(pt.size()));
		double sum = 0;
		for(size_t i=0;i<n;i++)
		{
			const double diff = pt[i]-arr[i];
			sum += diff*diff;	//plain multiply instead of the general pow()
		}

		return sqrt(sum);
	}

	//record the index of the cluster this point belongs to
	void setClusterID(int clusterID)
	{
		this->clusterID = clusterID;
	}

	//index of the owning cluster, or -1 if the point is unassigned
	int getClusterID() const
	{
		return clusterID;
	}

	//print the coordinates as "(x0,x1,...)" followed by a newline
	void info_Vct() const
	{
		cout<<"(";
		for(size_t i=0;i<arr.size();i++)
		{
			if(i!=0)
			{
				cout<<",";
			}
			cout<<arr[i];
		}

		cout<<")";
		cout<<endl;
	}

private:
	vector<double> arr;	//the coordinates
	int clusterID;		//owning cluster index, -1 when unassigned
};


/************************************************************************/
/*                       @Class	Cluster                                 */
/*						 @brief the representation of one cluster       */
/************************************************************************/
class Cluster
{
public :

	//add new pt-ID belong to 
	void addMember(int vctId)
	{
		member.insert(vctId);
	}

	//reverse pt-ID
	void deleMember(int vctId)
	{
		member.erase(vctId);
	}

	//judege whehter ptID belong the cluster or not
	bool isExistVct(int pt)
	{
		if(member.find(pt)!=member.end())
		{
			return true;
		}else
		{
			return false;
		}
	}

	//get the size of cluster
	int size()
	{

		return member.size();
	}

	//get the center of cluster
	Vector getCenter()
	{
		return vct;
	}

	//set the center of cluster(vct.size() == 0)
	void setCenter(const Vector& vct)
	{
		for(size_t i=0;i<vct.size();i++)
		{
			this->vct.push_back(vct[i]);
		}
	}

	//update the center of cluster (vct.size() !=0)
	void updateCenter(const Vector& vct)
	{
		for(size_t i=0;i<vct.size();i++)
		{
			this->vct[i] = vct[i];
		}
	}

	//reset the center when new-cluster formed
	void resetCenter(const vector<Vector>& pt)
	{
		Vector vctTmp;

		for(size_t i =0;i<vct.size();i++)
		{
			vctTmp.push_back(0);
		}

		for(set<int,less<int> >::iterator sIter = member.begin();sIter!=member.end();sIter++)
		{
			for(size_t j =0;j<vct.size();j++)
			{
				vctTmp[j] +=(pt[*sIter])[j];
			}
		}

		for(i=0;i<vct.size();i++)
		{
			vctTmp[i] = vctTmp[i] / member.size();
		}

		updateCenter(vctTmp);
	}

	//get the detail-information of the cluster
	void info_Cluster(vector<Vector>& pt= vector<Vector>())
	{
		cout<<"The center of the Cluster is :";
		vct.info_Vct();
		cout<<endl;

		for(set<int,less<int> > ::iterator sIter = member.begin();sIter!=member.end();sIter++)
		{
			if(pt.size()==0)
			{
				cout<<*sIter<<endl;
			}
			else
			{
				pt[*sIter].info_Vct();
			}
		}

		cout<<endl;
	}

	
private:
	Vector vct;
	set<int,less<int> >  member;
};


/************************************************************************/
/*                        @Class	Kmeans                              */
/*						  @brief    the structure of KMeans-Algorithm   */
/************************************************************************/
class KMeans
{
public:

	KMeans(const string& inputFile)
	{
		fInputFile.open(inputFile.c_str(),ios::in);

		if(fInputFile.is_open())
		{
			fInputFile>>volumn>>dimen>>kCluster;
		}

		string strline;

		getline(fInputFile,strline);


		volumn = 0;
		while(getline(fInputFile,strline))
		{
			volumn ++;
			istringstream sfin(strline);
			double data;

			Vector vctPt;

			while(sfin>>data)
			{
				vctPt.push_back(data);
			}

			totalPt.push_back(vctPt);
		}
	}

	//init the clusters 
	void initCluster()
	{
		for(int i=0;i<kCluster;i++)
		{
			Cluster clusterPt;

			clusterPt.addMember(i);
			clusterPt.setCenter(totalPt[i]);

			totalPt[i].setClusterID(i);

			totalCluster.push_back(clusterPt);

		}
	}

	//set the max-iterater number
	void setStopCount(int stopCount)
	{
		this->stopCount = stopCount;
	}


	//main-run 
	void run(int MAXTOP=100)
	{
		initCluster();

		setStopCount(MAXTOP);
		
		bool flag = false;

		int count ;

		//iterator to find
		while(!flag && count++<stopCount)
		{
			flag = true;

			//iterator to update the cluster
			for(size_t ptIter =0;ptIter< volumn;ptIter++)
			{
				size_t minIter = 0;
				double minDist = totalPt[ptIter].dist(totalCluster[0].getCenter());
				for(size_t cluIter =1;cluIter<kCluster;cluIter++)
				{
					double tmpDist = totalPt[ptIter].dist(totalCluster[cluIter].getCenter());

					if(tmpDist < minDist)
					{
						minDist = tmpDist;
						minIter = cluIter;
					}
				}

				//update
				if(!(totalCluster[minIter].isExistVct(ptIter)))
				{
					flag = false;
					int clusterID = totalPt[ptIter].getClusterID();

					if(clusterID!=-1)
					{
						totalCluster[clusterID].deleMember(ptIter);
					}

					totalPt[ptIter].setClusterID(ptIter);
					totalCluster[minIter].addMember(ptIter);
				}

			}

			for(size_t i =0;i<totalCluster.size();i++)
			{
				totalCluster[i].resetCenter(totalPt);
			}
		}

	}

	//get the final cluster
	void getClusters(bool flag=false)
	{
		for(size_t cluIter = 0;cluIter < totalCluster.size();cluIter++)
		{
			if(!flag)
			{
				totalCluster[cluIter].info_Cluster();
			}else
			{
				totalCluster[cluIter].info_Cluster(totalPt);
			}

		}
	}
private:
	vector<Vector> totalPt;
	vector<Cluster> totalCluster;

	ifstream fInputFile;

	int dimen;
	int volumn;
	int kCluster;

	int stopCount;

private :

};

int main()
{	/*the structure of the data
        3 3 4
        1 2 4
        2 3 4
        1 2 3
       */
	KMeans kmeans("./data");
	kmeans.run();
	kmeans.getClusters(true);
	return 0;
}

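/*
 * The text below is residue scraped from the hosting blog page (comment box
 * and "red packet" tipping dialog). It is not part of the program and is kept
 * only as a comment so that the file compiles.
 *
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
*/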