用纯java实现一个随机生成点的k-means聚类算法,附带散点图输出结果
k均值聚类算法的思想很简单,就是给定一个数据点集合和需要的聚类数目k,k由用户指定,根据距离函数反复把数据分入k个聚类中。
我用java简单实现了二维向量在平面随即撒点,之后随机选两个点作为两个聚类的中心,根据欧氏距离的判断哪些点属于哪些类。之后将每个类分别计算几何中心,随后将全体向量重新归类,反复进行多次直到两个中心点不在移动,这两个聚类就完成了。随后将两个聚类绘制成散点图表示出来。
不过存在几个问题:
- 随机点生成在整个平面,数据本身不存在聚类的特性,针对这个问题,在生成随机点的时候加一个偏置,人工生成聚类数据。
- 暂时没想到怎么将聚类的方法抽象化,当前代码是在main方法里实现的,因为感觉将多个聚类包装成集合返回有点不太合适,希望大家可以提供更优雅的代码
- 对于更高维的数据,部分代码需要重构,高效应对更高维数据的代码或许像上一点一样会产生多一个集合,有点浪费存储,同样不是很优雅。
代码结构
使用ArrayList保存向量组,向量的元素为double类型。中心点即向量,
ArrayList<Vector<Double>> vectorList ;
Vector<Double> center;
使用方法生成随即向量组,随即生成数据的时候将数据分为两拨,分别加减一个偏置值30,这里也可以设置一个形参进行调整
public static ArrayList<Vector<Double>> generateVectorGroup(int vectorNum, int dim){
ArrayList<Vector<Double>> vectorList = new ArrayList<Vector<Double>>();
//随机填充向量组
for(int i=0;i<vectorNum/2;i++) {//生成第i个向量
Vector<Double> vTemp = new Vector<>();
for(int j=0;j<dim;j++) {//向第i个向 量填充元素
double randomTemp = Math.random()*100-30;
vTemp.add((randomTemp)<0?(randomTemp+Math.random()*40):(randomTemp));
}vectorList.add(vTemp);//讲向量加入向量组中
}
//增加偏置因子让数据更好的显示类别
for(int i=0;i<vectorNum/2;i++) {//生成第i个向量
Vector<Double> vTemp = new Vector<>();
for(int j=0;j<dim;j++) {//向第i个向 量填充元素
double randomTemp = Math.random()*100+30;
vTemp.add((randomTemp)>100?(randomTemp-Math.random()*40):(randomTemp));
}vectorList.add(vTemp);//讲向量加入向量组中
}
return vectorList;
}
定义一个计算两向量(二维点)距离的方法:
//计算向量的欧式距离
public static double calculationVectorDistance(Vector<Double> v1,Vector<Double> v2, int dim) {
double distance = 0.0;
for(int i=0;i<dim;i++) {
distance += Math.pow((v1.get(i)-v2.get(i)),2);
}
return Math.sqrt(distance);
}
定义一个计算向量组中心的方法,分别计算x和y的均值,需要注意,操作是按列进行的,特别注意两层循环遍历变量的范围
//计算聚类的中心
public static Vector<Double> computeCenter(ArrayList<Vector<Double>> cluster, int dim){
Vector<Double> center = new Vector<>();
for(int i=0;i<dim;i++) {//遍历维数,纵向进行
double calculate = 0.0;
for(int j=0;j<cluster.size();j++) {//遍历向量
calculate+=cluster.get(j).get(i);
}
center.add(calculate/cluster.size());
}
return center;
}
接下来就是main方法,给定点数和维度,调用生成向量组的方法,随机选取两个中心进行迭代计算:
需要注意循环条件的判断。每轮计算可以输出当前的中心和距离等数值以供观察,这里没放,在文后的代码下载链接里有完整代码
int turn = 0;
//开始聚类 使用do-while,直到中心收敛到某个范围内停止计算
//每次迭代重新计算所有点和中心的距离,每次开始前清空聚类
do {
//更新临时节点作为对比
center1Temp.clear();
center2Temp.clear();
for(int i=0;i<2;i++) {
center1Temp.add(center1.get(i));
center2Temp.add(center2.get(i));
}
cluster1.clear();
cluster2.clear();
for(int i=0;i<vectorList.size();i++) {
//计算向量组中所有点到本轮中心点的距离,根据距离差分为两类
double distance1 = calculationVectorDistance(vectorList.get(i),center1,dim);
double distance2 = calculationVectorDistance(vectorList.get(i),center2,dim);
if(distance1>=distance2) {
//该向量属于第二个类
cluster2.add(vectorList.get(i));
}else {
cluster1.add(vectorList.get(i));
}
}//聚类完成
//计算每个聚类的中心
center1= computeCenter(cluster1,dim);
center2 = computeCenter(cluster2,dim);
turn++;
}while(calculationVectorDistance(center1Temp, center1,dim)>= threshold && calculationVectorDistance(center2Temp, center2,dim)>= threshold);
接下来使用javafx的包进行图像显示
for (int i=0;i<cluster1.size();i++) {
//括号内为横纵坐标
aSeries.getData().add(new XYChart.Data(cluster1.get(i).get(0), cluster1.get(i).get(1)));
}
for (int i=0;i<cluster2.size();i++) {
bSeries.getData().add(new XYChart.Data(cluster2.get(i).get(0), cluster2.get(i).get(1)));
}
aCenter.getData().add(new XYChart.Data(center1.get(0), center1.get(1)));
bCenter.getData().add(new XYChart.Data(center2.get(0), center2.get(1)));
answer.addAll(aSeries, bSeries,aCenter,bCenter);
运行结果
可以从每轮的输出中看到一步步的距离和中心的变动
javaFX绘图工具包下载地址
代码下载链接:
链接