k-means算法思想
k-means算法是一种迭代求解的聚类分析算法,其功能是将平面上的各个点按照聚集程度进行分类。它的实现过程有如下几个步骤:
1.随机选取K个点作为初始的聚类中心。
2.分别计算每个点与k个聚类中心之间的距离,把每个点分配给距离它最近的聚类中心。聚类中心以及分配给它们的点就代表一个聚类,总共会形成k个聚类。
3.遍历每一个聚类,算出聚类的中心位置,作为新的聚类中心。
4.由于聚类中心的位置改变了,每个点与聚类中心的距离也会相应改变,所以需要循环第2、3步,直到计算后的聚类中心没有发生改变。
.
代码实现
设置k值为2,以(20,20)和(40,40)为中心随机生成20个点
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
public class Point
{
int x;
int y;
Point(){
x=0;y=0;
}
int getX(){
return x;
}
int getY(){
return y;
}
void set(int a,int b){
x=a;y=b;
}
void print(){
System.out.println("("+x+","+y+")");
}
void random1(){
int a=(int)(1+Math.random()*40);
int b=(int)(1+Math.random()*40);
this.set(a,b);
}
void random2(){
int a=(int)(20+Math.random()*40);
int b=(int)(20+Math.random()*40);
this.set(a,b);
}
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
Point point = (Point) o;
return x == point.x &&
y == point.y;
}
@Override
public int hashCode() {
return Objects.hash(x, y);
}
public static Point update(ArrayList<Point> list){
Point t=new Point();
int count=0;
for(int j=0;j<list.size();j++){
count++;
t.x+=list.get(j).getX();
t.y+=list.get(j).getY();
}
t.x/=count;
t.y/=count;
return t;
}
public static void main(String []args){
Point[] p = new Point[20];
for(int i=0;i<20;i++){
p[i] = new Point();
}
//设置两个中心点
Point k1=new Point();
Point k2=new Point();
k1.set(20,20);
k2.set(40,40);
//随机生成横纵坐标0~40
for(int i=0;i<10;i++ ){
p[i].random1();
}
//随机生成横纵坐标20~60
for(int i=10;i<20;i++ ){
p[i].random2();
}
System.out.println("\n");
Point a=new Point();
Point b=new Point();
//迭代直到中心点不变
while(true){
a.set(k1.getX(),k1.getY());
b.set(k2.getX(),k2.getY());
ArrayList<Point> L1 = new ArrayList<>();
ArrayList<Point> L2 = new ArrayList<>();
for(int i=0;i<20;i++){
int d1=(k1.getX()-p[i].getX())*(k1.getX()-p[i].getX())+(k1.getY()-p[i].getY())*(k1.getY()-p[i].getY());
int d2=(k2.getX()-p[i].getX())*(k2.getX()-p[i].getX())+(k2.getY()-p[i].getY())*(k2.getY()-p[i].getY());
if(d1<d2){
L1.add(p[i]);
}else{
L2.add(p[i]);
}
}
System.out.println("k1:");
for(int i=0;i<L1.size();i++){
L1.get(i).print();
}
System.out.println("k2:");
for(int i=0;i<L2.size();i++){
L2.get(i).print();
}
k1=update(L1);
k2=update(L2);
System.out.println("更新后的中心点:");
k1.print();
k2.print();
if(a.equals(k1)&&b.equals(k2)){
System.out.println("中心点位置不变,循环结束");
break;
}
}
}
}
截图
图片