本文转载自 http://www.cnblogs.com/zhangchaoyang/articles/2181869.html ,并对其内容进行了补充和完善,使代码可以直接运行;运算的原始数据由随机数产生。
图示为3个簇,1000个二维变量的分类结果
主程序:
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;
/**
 * Simple k-means clustering over {@link DataObject} vectors.
 *
 * <p>Each call to {@link #run} seeds the centroids from a random partition of the objects,
 * alternates {@link #classify} and {@link #calNewCenter} until no centroid moves by more than
 * {@code mu}, writes every cluster to a file named {@code out<i>.txt} and returns the score
 * computed by {@link #getSati}.
 */
public class KMeans {
    int k;              // number of clusters
    double mu;          // convergence threshold: iteration stops once no centroid moves farther than this
    double[][] center;  // current centroid of each cluster, indexed [cluster][dimension]
    int repeat;         // number of independent runs performed by main
    double[] crita;     // one slot per run; allocated by the constructor but not written by this class

    /**
     * @param k      number of clusters
     * @param mu     convergence threshold on the centroid shift
     * @param repeat number of runs {@code main} performs
     * @param len    dimensionality of the data vectors
     */
    public KMeans(int k, double mu, int repeat, int len) {
        this.k = k;
        this.mu = mu;
        this.repeat = repeat;
        center = new double[k][];
        for (int i = 0; i < k; i++) {
            center[i] = new double[len];
        }
        crita = new double[repeat];
    }

    /**
     * Seeds the k centroids: every object is assigned to a random cluster and each centroid
     * becomes the mean of the objects assigned to it.
     *
     * @param len     dimensionality of the data vectors
     * @param objects the objects to partition; their cluster ids are overwritten
     */
    public void initCenter(int len, ArrayList<DataObject> objects) {
        // Start from zero so centroids left over from a previous run do not leak into the sums.
        for (int i = 0; i < k; i++) {
            for (int j = 0; j < len; j++) {
                center[i][j] = 0;
            }
        }
        // No explicit seed: a millisecond-clock seed repeats when two runs start within the
        // same millisecond, which would make those runs identical.
        Random random = new Random();
        int[] count = new int[k];
        for (DataObject object : objects) {
            int id = random.nextInt(k);
            count[id]++;
            object.setCid(id);
            double[] vector = object.getVector();
            for (int i = 0; i < len; i++) {
                center[id][i] += vector[i];
            }
        }
        for (int i = 0; i < k; i++) {
            if (count[i] != 0) {
                for (int j = 0; j < len; j++) {
                    center[i][j] /= count[i];
                }
            } else if (!objects.isEmpty()) {
                // Nothing was assigned to this cluster: use a random data point as its centroid
                // instead of dividing by zero (which would yield NaN coordinates).
                double[] seed = objects.get(random.nextInt(objects.size())).getVector();
                System.arraycopy(seed, 0, center[i], 0, len);
            }
        }
    }

    /**
     * Assigns every object to the cluster whose centroid is nearest in Euclidean distance.
     * Ties go to the lowest cluster index.
     *
     * @param objects the objects to classify; their cluster ids are overwritten
     */
    public void classify(ArrayList<DataObject> objects) {
        for (DataObject object : objects) {
            double[] vector = object.getVector();
            int len = vector.length;
            int index = 0;
            double neardist = Double.MAX_VALUE;
            for (int i = 0; i < k; i++) {
                double dist = Global.calEuraDist(vector, center[i], len);
                if (dist < neardist) {
                    neardist = dist;
                    index = i;
                }
            }
            object.setCid(index);
        }
    }

    /**
     * Recomputes each centroid as the mean of the objects currently assigned to it.
     *
     * <p>A cluster with no members gets the average of the old centroids at indices
     * {@code (i+1)%k}, {@code (i+3)%k} and {@code (i+5)%k} as its new centroid.
     *
     * @param objects the classified objects
     * @param len     dimensionality of the data vectors
     * @return {@code true} when no centroid moved by more than {@code mu} (converged; the stored
     *         centroids are left unchanged), {@code false} when the centroids were updated
     */
    public boolean calNewCenter(ArrayList<DataObject> objects, int len) {
        int[] count = new int[k];
        double[][] sum = new double[k][len];
        for (DataObject object : objects) {
            int id = object.getCid();
            count[id]++;
            double[] vector = object.getVector();
            for (int i = 0; i < len; i++) {
                sum[id][i] += vector[i];
            }
        }
        for (int i = 0; i < k; i++) {
            if (count[i] != 0) {
                for (int j = 0; j < len; j++) {
                    sum[i][j] /= count[i];
                }
            } else {
                // Write the fallback into the candidate centroid (sum), not into center:
                // center[i] is overwritten from sum[i] below, so a fallback stored there
                // would be replaced by the all-zero sum.
                int a = (i + 1) % k;
                int b = (i + 3) % k;
                int c = (i + 5) % k;
                for (int j = 0; j < len; j++) {
                    sum[i][j] = (center[a][j] + center[b][j] + center[c][j]) / 3;
                }
            }
        }
        boolean end = true;
        for (int i = 0; i < k; i++) {
            if (Global.calEuraDist(sum[i], center[i], len) > mu) {
                end = false;
                break;
            }
        }
        if (!end) {
            for (int i = 0; i < k; i++) {
                System.arraycopy(sum[i], 0, center[i], 0, len);
            }
        }
        return end;
    }

    /**
     * Computes the clustering score: for each cluster, the sum of squared deviations of its
     * members from the centroid multiplied by the member count, summed over all clusters.
     * {@code main} treats a smaller value as a better result.
     *
     * @param objects the classified objects
     * @param len     dimensionality of the data vectors
     * @return the score described above
     */
    public double getSati(ArrayList<DataObject> objects, int len) {
        int[] count = new int[k];
        double[] ss = new double[k];
        for (DataObject object : objects) {
            int id = object.getCid();
            count[id]++;
            double[] vector = object.getVector();
            for (int i = 0; i < len; i++) {
                double diff = vector[i] - center[id][i];
                ss[id] += diff * diff;
            }
        }
        double satisfy = 0;
        for (int i = 0; i < k; i++) {
            satisfy += count[i] * ss[i];
        }
        return satisfy;
    }

    /**
     * Performs one complete clustering run and writes the result to disk.
     *
     * @param round      run number, used only in the progress message
     * @param datasource supplier of the objects to cluster
     * @param len        dimensionality of the data vectors
     * @return the score of the final clustering as computed by {@link #getSati}
     * @throws IOException if a cluster file cannot be written
     */
    public double run(int round, DataSource datasource, int len) throws IOException {
        System.out.println("第" + round + "次运行");
        ArrayList<DataObject> objects = datasource.getObjects();
        initCenter(len, objects);
        classify(objects);
        while (!calNewCenter(objects, len)) {
            classify(objects);
        }
        writeClusters(objects);
        double ss = getSati(objects, len);
        System.out.println("加权方差:" + ss);
        return ss;
    }

    /** Writes the members of cluster i to "out" + i + ".txt", one space-separated vector per line. */
    private void writeClusters(ArrayList<DataObject> objects) throws IOException {
        for (int cluster = 0; cluster < k; cluster++) {
            // One writer at a time, closed even if a write fails.
            try (FileWriter writer = new FileWriter(new File("out" + cluster + ".txt"))) {
                for (DataObject object : objects) {
                    if (object.getCid() != cluster) {
                        continue;
                    }
                    for (double value : object.getVector()) {
                        writer.write(Double.toString(value));
                        writer.write(" ");
                    }
                    writer.write("\n");
                }
            }
        }
    }

    /**
     * Clusters 1000 random two-dimensional points into 3 clusters, 10 times over, and reports
     * which run produced the smallest score.
     */
    public static void main(String[] args) throws IOException {
        int k = 3;
        double mu = 1E-10;
        int repeat = 10;
        int len = 2;
        int node = 1000;
        DataSource resource = new DataSource(node, len);
        KMeans km = new KMeans(k, mu, repeat, len);
        int index = 0;
        double minsa = Double.MAX_VALUE;
        for (int i = 0; i < km.repeat; i++) {
            double ss = km.run(i, resource, len);
            if (ss < minsa) {
                minsa = ss;
                index = i;
            }
        }
        System.out.println("最好的结果是第" + index + "次。");
    }
}
DataSource
import java.io.*;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.Random;
import java.util.Scanner;
/**
 * Generates a file of random data points and loads it back as a list of {@link DataObject}s.
 */
class DataSource {
    /** File the random data is written to and then read back from (relative to the working directory). */
    private static final String DATA_FILE = "o.txt";

    int node; // number of nodes
    int len;  // dimensionality
    ArrayList<DataObject> objects;

    /**
     * Writes {@code node} random vectors of {@code len} values to the data file, then reads them
     * back into {@link #getObjects()}.
     *
     * @param node number of data points to generate
     * @param len  number of values per data point
     * @throws IOException if the data file cannot be written or opened
     */
    public DataSource(int node, int len) throws IOException {
        this.node = node;
        this.len = len;
        initData();
        objects = new ArrayList<DataObject>();
        // Read the very file initData() just wrote; the original read a hard-coded absolute
        // path that only existed on the author's machine.
        try (Scanner scanner = new Scanner(new FileInputStream(DATA_FILE))) {
            for (int i = 0; i < node; i++) {
                double[] data = new double[this.len];
                for (int j = 0; j < this.len; j++) {
                    data[j] = scanner.nextDouble();
                }
                objects.add(new DataObject(data, len));
            }
        }
    }

    /** Returns the loaded data objects (the internal list itself, not a copy). */
    public ArrayList<DataObject> getObjects() {
        return objects;
    }

    /**
     * Writes the random data file: one line per node, each holding {@code len} values drawn
     * from [0, 1000) and formatted with two decimals, every value followed by a space.
     *
     * @throws IOException if the file cannot be written
     */
    void initData() throws IOException {
        Random random = new Random(System.currentTimeMillis());
        DecimalFormat fnum = new DecimalFormat("0.00"); // created once, reused for every value
        try (FileWriter writer = new FileWriter(new File(DATA_FILE))) {
            for (int i = 0; i < node; i++) {
                for (int j = 0; j < len; j++) {
                    double t1 = random.nextDouble() * 1000;
                    writer.write(fnum.format(t1));
                    writer.write(" ");
                }
                writer.write("\n");
            }
        }
    }
}
DataObject
/**
 * A single data point: a private copy of its coordinate values plus the id of the cluster it
 * is currently assigned to.
 */
class DataObject {
    double[] data; // value of each dimension
    int len;       // number of dimensions
    int index;     // id of the cluster this point belongs to

    /**
     * Creates a point that starts in cluster 0.
     *
     * @param vector coordinate values; all {@code vector.length} entries are copied
     * @param len    size of the internal coordinate array
     */
    public DataObject(double[] vector, int len) {
        this.len = len;
        this.index = 0;
        double[] copy = new double[len];
        System.arraycopy(vector, 0, copy, 0, vector.length);
        this.data = copy;
    }

    /** Returns the internal coordinate array (not a copy). */
    public double[] getVector() {
        return this.data;
    }

    /** Sets the id of the cluster this point belongs to. */
    public void setCid(int index) {
        this.index = index;
    }

    /** Returns the id of the cluster this point belongs to. */
    public int getCid() {
        return this.index;
    }
}
Global
/** Shared numeric helpers. */
class Global {
    /**
     * Euclidean distance between two vectors over their first {@code len} components.
     *
     * @param vector first vector
     * @param center second vector
     * @param len    number of leading components to compare
     * @return square root of the sum of squared component differences
     */
    public static double calEuraDist(double[] vector, double[] center, int len) {
        double squared = 0;
        int dim = 0;
        while (dim < len) {
            squared += Math.pow(vector[dim] - center[dim], 2);
            dim++;
        }
        return Math.sqrt(squared);
    }
}