/************************************************************************/
/* @Author Amiber */
/* @Date 2012-11-25 */
/* @brief: KMeans-Algorithm */
/************************************************************************/
#pragma warning(disable:4786)
#include<iostream> //for cout,...
#include<algorithm> //for algorithms
#include<string> // for
#include<set> //for set<int,less<int> >
#include<vector> //for vector<>
#include<cmath> //for sqrt,pow
#include<iterator> //for set<int,less<int> >::iterator
#include<fstream> //for ifstream
#include<sstream> // for istringstream
using namespace std;
/************************************************************************/
/* @Class Vector */
/* @brief the representation of point-vector */
/************************************************************************/
class Vector
{
public:
    // A new point has no coordinates and is not assigned to any cluster
    // (clusterID == -1 means "unassigned").
    Vector() : clusterID(-1)
    {
    }
    // Mutable access to coordinate `index` (unchecked).
    double& operator [] (size_t index)
    {
        return arr[index];
    }
    // Read-only access to coordinate `index` (unchecked).
    double operator [] (size_t index) const
    {
        return arr[index];
    }
    // Overwrite coordinate `index` with `value` (unchecked).
    void setValue(int index, double value)
    {
        arr[index] = value;
    }
    // Append a coordinate, increasing the dimension of the point by one.
    void push_back(double value)
    {
        arr.push_back(value);
    }
    // Number of coordinates (the dimension of the point).
    int size() const
    {
        return static_cast<int>(arr.size());
    }
    // Euclidean distance between this point and `pt`.
    // Only the coordinates present in BOTH points are compared, so points of
    // mismatched dimension never read past the end of either coordinate array.
    double dist(const Vector& pt) const
    {
        const size_t common = std::min(arr.size(), pt.arr.size());
        double sum = 0;
        for (size_t i = 0; i < common; i++)
        {
            const double diff = pt.arr[i] - arr[i];
            sum += diff * diff; // cheaper and exact compared with pow(diff, 2.0)
        }
        return std::sqrt(sum);
    }
    // Record the index of the cluster this point belongs to (-1 = none).
    void setClusterID(int clusterID)
    {
        this->clusterID = clusterID;
    }
    // Index of the cluster this point belongs to, or -1 if unassigned.
    int getClusterID() const
    {
        return clusterID;
    }
    // Print the point as "(x0,x1,...)" followed by a newline.
    void info_Vct() const
    {
        std::cout << "(";
        for (size_t i = 0; i < arr.size(); i++)
        {
            if (i != 0)
            {
                std::cout << ",";
            }
            std::cout << arr[i];
        }
        std::cout << ")";
        std::cout << std::endl;
    }
private:
    std::vector<double> arr; // coordinates of the point
    int clusterID;           // owning cluster index, -1 when unassigned
};
/************************************************************************/
/* @Class Cluster */
/* @brief the representation of one cluster */
/************************************************************************/
class Cluster
{
public :
//add new pt-ID belong to
void addMember(int vctId)
{
member.insert(vctId);
}
//reverse pt-ID
void deleMember(int vctId)
{
member.erase(vctId);
}
//judege whehter ptID belong the cluster or not
bool isExistVct(int pt)
{
if(member.find(pt)!=member.end())
{
return true;
}else
{
return false;
}
}
//get the size of cluster
int size()
{
return member.size();
}
//get the center of cluster
Vector getCenter()
{
return vct;
}
//set the center of cluster(vct.size() == 0)
void setCenter(const Vector& vct)
{
for(size_t i=0;i<vct.size();i++)
{
this->vct.push_back(vct[i]);
}
}
//update the center of cluster (vct.size() !=0)
void updateCenter(const Vector& vct)
{
for(size_t i=0;i<vct.size();i++)
{
this->vct[i] = vct[i];
}
}
//reset the center when new-cluster formed
void resetCenter(const vector<Vector>& pt)
{
Vector vctTmp;
for(size_t i =0;i<vct.size();i++)
{
vctTmp.push_back(0);
}
for(set<int,less<int> >::iterator sIter = member.begin();sIter!=member.end();sIter++)
{
for(size_t j =0;j<vct.size();j++)
{
vctTmp[j] +=(pt[*sIter])[j];
}
}
for(i=0;i<vct.size();i++)
{
vctTmp[i] = vctTmp[i] / member.size();
}
updateCenter(vctTmp);
}
//get the detail-information of the cluster
void info_Cluster(vector<Vector>& pt= vector<Vector>())
{
cout<<"The center of the Cluster is :";
vct.info_Vct();
cout<<endl;
for(set<int,less<int> > ::iterator sIter = member.begin();sIter!=member.end();sIter++)
{
if(pt.size()==0)
{
cout<<*sIter<<endl;
}
else
{
pt[*sIter].info_Vct();
}
}
cout<<endl;
}
private:
Vector vct;
set<int,less<int> > member;
};
/************************************************************************/
/* @Class KMeans */
/* @brief the structure of KMeans-Algorithm */
/************************************************************************/
class KMeans
{
public:
    // Load the point set from `inputFile`.
    // Layout: the first line is the header "volumn dimen kCluster"; every later
    // line holds one point as whitespace-separated numbers. The point count is
    // recomputed from the lines actually read, and lines with no numbers are skipped.
    // If the file cannot be opened, no points are loaded and run() does nothing.
    KMeans(const string& inputFile)
        : dimen(0), volumn(0), kCluster(0), stopCount(0)
    {
        fInputFile.open(inputFile.c_str(), ios::in);
        if (!fInputFile.is_open())
        {
            return;
        }
        fInputFile>>volumn>>dimen>>kCluster;
        string strline;
        getline(fInputFile, strline); // consume the remainder of the header line
        volumn = 0;
        while (getline(fInputFile, strline))
        {
            istringstream sfin(strline);
            double data;
            Vector vctPt;
            while (sfin>>data)
            {
                vctPt.push_back(data);
            }
            if (vctPt.size() == 0)
            {
                continue; // blank / non-numeric line: not a point
            }
            totalPt.push_back(vctPt);
            volumn++;
        }
        fInputFile.close(); // everything is in memory now
    }
    // Seed the clusters: cluster i starts with point i as its only member and center.
    // Safe to call repeatedly; previous clusters and assignments are discarded.
    void initCluster()
    {
        totalCluster.clear();
        for (size_t p = 0; p < totalPt.size(); p++)
        {
            totalPt[p].setClusterID(-1);
        }
        // Cannot seed more clusters than there are points.
        const int available = static_cast<int>(totalPt.size());
        const int k = kCluster < available ? kCluster : available;
        for (int i = 0; i < k; i++)
        {
            Cluster clusterPt;
            clusterPt.addMember(i);
            clusterPt.setCenter(totalPt[i]);
            totalPt[i].setClusterID(i);
            totalCluster.push_back(clusterPt);
        }
    }
    // Set the maximum number of assignment/update iterations.
    void setStopCount(int stopCount)
    {
        this->stopCount = stopCount;
    }
    // Run Lloyd-style k-means for at most MAXTOP iterations, stopping early
    // once an iteration moves no point to a different cluster.
    void run(int MAXTOP=100)
    {
        initCluster();
        setStopCount(MAXTOP);
        if (totalCluster.empty())
        {
            return; // no points or kCluster <= 0: nothing to cluster
        }
        bool flag = false;
        int count = 0;
        while (!flag && count++ < stopCount)
        {
            flag = true;
            // Assignment step: move each point to its nearest center.
            for (size_t ptIter = 0; ptIter < totalPt.size(); ptIter++)
            {
                size_t minIter = 0;
                double minDist = totalPt[ptIter].dist(totalCluster[0].getCenter());
                for (size_t cluIter = 1; cluIter < totalCluster.size(); cluIter++)
                {
                    double tmpDist = totalPt[ptIter].dist(totalCluster[cluIter].getCenter());
                    if (tmpDist < minDist)
                    {
                        minDist = tmpDist;
                        minIter = cluIter;
                    }
                }
                const int ptId = static_cast<int>(ptIter);
                if (!totalCluster[minIter].isExistVct(ptId))
                {
                    flag = false;
                    int clusterID = totalPt[ptIter].getClusterID();
                    if (clusterID != -1)
                    {
                        totalCluster[clusterID].deleMember(ptId);
                    }
                    // Record the cluster the point moved to (previously stored the
                    // point index, which later indexed totalCluster out of range).
                    totalPt[ptIter].setClusterID(static_cast<int>(minIter));
                    totalCluster[minIter].addMember(ptId);
                }
            }
            // Update step: recompute every center as the mean of its members.
            for (size_t i = 0; i < totalCluster.size(); i++)
            {
                totalCluster[i].resetCenter(totalPt);
            }
        }
    }
    // Print every cluster: member indices when `flag` is false, member
    // coordinates when `flag` is true.
    void getClusters(bool flag=false)
    {
        for (size_t cluIter = 0; cluIter < totalCluster.size(); cluIter++)
        {
            if (!flag)
            {
                totalCluster[cluIter].info_Cluster();
            }
            else
            {
                totalCluster[cluIter].info_Cluster(totalPt);
            }
        }
    }
private:
    vector<Vector> totalPt;       // all loaded points
    vector<Cluster> totalCluster; // current clusters
    ifstream fInputFile;          // input stream, closed after loading
    int dimen;                    // dimension from the header (read, not used)
    int volumn;                   // number of points actually loaded
    int kCluster;                 // requested number of clusters
    int stopCount;                // maximum number of iterations
};
int main()
{
    /* Expected layout of ./data:
       first line : <volumn> <dimen> <kCluster>
       other lines: one point per line, whitespace-separated coordinates
       e.g.
       3 3 2
       1 2 4
       2 3 4
       1 2 3
       kCluster should not exceed the number of point lines. */
    KMeans kmeans("./data");
    kmeans.run();
    kmeans.getClusters(true);
    return 0;
}
// STL-KMeans-Algorithm
// (web-page residue: "latest recommended article published on 2024-09-12 19:02:20")