对于K-Means算法想必做机器学习和数据挖掘的广大同胞们已经不再陌生,做为数据挖据的十大经典算法之一,k-Means做聚类分析上有得天独厚的优势。对于其原理进行简单的描述:
k-Means算法是典型的基于距离的聚类算法,采用的是距离作为相似性指标。经过n次迭代后,当中心的位置不在发生变换的时候即是收敛完成。
算法:
1. 从n个文档中随机的选择出k个文档作为质心
2.从剩余的文档中测量出每个文档到质心的距离,并归类到最小质心的一类中
3. 重新计算质心的位置
4.重复2-3步,直到迭代完成。
由以上步骤,可以有java实现K-Means算法。随机产生100个点,设置k=5后进行聚类操作:
1.主函数:
- package KMeans;
- import java.util.ArrayList;
- /**
- * K-Means算法
- * @author Administrator
- *
- */
- public class k_means {
- /**
- * @param args
- */
- public static void main(String[] args) {
- //1.创建二维数组 10x10的数组
- int num_1[]=new int[100];
- int num_2[]=new int[100];
- //随机赋值
- for(int i=0;i<100;i++){
- num_1[i]=(int)( Math.random()*100);
- }
- for(int i=0;i<100;i++){
- num_2[i]=(int)( Math.random()*100);
- }
- // 2.创建点坐标
- ArrayList<pointBean> list=new ArrayList<pointBean>();
- pointBean bean;
- for(int i=0;i<100;i++){
- bean=new pointBean();
- bean.point_x=num_1[i];
- bean.point_y=num_2[i];
- list.add(bean);
- }
- // 执行k-means算法
- getDataKMeans gg=new getDataKMeans();
- gg.setData(list);
- }
- }
- package KMeans;
- public class pointBean {
- int point_x;
- int point_y;
- public int getPoint_x() {
- return point_x;
- }
- public void setPoint_x(int point_x) {
- this.point_x = point_x;
- }
- public int getPoint_y() {
- return point_y;
- }
- public void setPoint_y(int point_y) {
- this.point_y = point_y;
- }
- @Override
- public String toString() {
- return "pointBean [point_x=" + point_x + ", point_y=" + point_y + "]";
- }
- public pointBean(int point_x, int point_y) {
- super();
- this.point_x = point_x;
- this.point_y = point_y;
- }
- public pointBean() {
- super();
- }
- }
- package KMeans;
- import java.util.ArrayList;
- public class getDataKMeans {
- int k=5;//k值
- //第一个中心点x,y
- static double con1_x;
- static double con1_y;
- //第一个中心点x,y
- static double con2_x;
- static double con2_y;
- //第一个中心点x,y
- static double con3_x;
- static double con3_y;
- //第一个中心点x,y
- static double con4_x;
- static double con4_y;
- //第一个中心点x,y
- static double con5_x;
- static double con5_y;
- //创建5个list装各个点
- ArrayList<pointBean> list1=new ArrayList<pointBean>();
- ArrayList<pointBean> list2=new ArrayList<pointBean>();
- ArrayList<pointBean> list3=new ArrayList<pointBean>();
- ArrayList<pointBean> list4=new ArrayList<pointBean>();
- ArrayList<pointBean> list5=new ArrayList<pointBean>();
- public void setData(ArrayList<pointBean> list){
- con1_x=list.get(0).point_x;
- con1_y=list.get(0).point_y;
- con2_x=list.get(1).point_x;
- con2_y=list.get(1).point_y;
- con3_x=list.get(2).point_x;
- con3_y=list.get(2).point_y;
- con4_x=list.get(3).point_x;
- con4_y=list.get(3).point_y;
- con5_x=list.get(4).point_x;
- con5_y=list.get(4).point_y;
- //分别加入list中
- list1.add(list.get(0));
- list2.add(list.get(1));
- list3.add(list.get(2));
- list4.add(list.get(3));
- list5.add(list.get(4));
- //循环操作
- for(int i=5;i<list.size();i++){
- getLength(list.get(i));
- }
- // 打印出对应的中心点 、聚类的值
- System.out.println("-------1-------");
- System.out.println("1的中心点:"+con1_x+" "+con1_y);
- for(int i=0;i<list1.size();i++){
- System.out.println(list1.get(i).point_x+" "+list1.get(i).point_y);
- }
- System.out.println("-------2-------");
- System.out.println("2的中心点:"+con2_x+" "+con2_y);
- for(int i=0;i<list2.size();i++){
- System.out.println(list2.get(i).point_x+" "+list2.get(i).point_y);
- }
- System.out.println("-------3-------");
- System.out.println("3的中心点:"+con3_x+" "+con3_y);
- for(int i=0;i<list3.size();i++){
- System.out.println(list3.get(i).point_x+" "+list3.get(i).point_y);
- }
- System.out.println("-------4-------");
- System.out.println("4的中心点:"+con4_x+" "+con4_y);
- for(int i=0;i<list4.size();i++){
- System.out.println(list4.get(i).point_x+" "+list4.get(i).point_y);
- }
- System.out.println("-------5-------");
- System.out.println("5的中心点:"+con5_x+" "+con5_y);
- for(int i=0;i<list5.size();i++){
- System.out.println(list5.get(i).point_x+" "+list5.get(i).point_y);
- }
- }
- /**
- * 求出每个点到中心点距离
- * @param point
- */
- public void getLength(pointBean point) {
- int x=point.point_x;
- int y=point.point_y;
- double s1=(x-con1_x)*(x-con1_x)+(y-con1_y)*(y-con1_y);
- double s2=(x-con2_x)*(x-con2_x)+(y-con2_y)*(y-con2_y);
- double s3=(x-con3_x)*(x-con3_x)+(y-con3_y)*(y-con3_y);
- double s4=(x-con4_x)*(x-con4_x)+(y-con4_y)*(y-con4_y);
- double s5=(x-con5_x)*(x-con5_x)+(y-con5_y)*(y-con5_y);
- double nn[]={s1,s2,s3,s4,s5};
- // 找出最小的一个
- double temp=nn[0];
- for(int i=1;i<nn.length;i++){
- if(nn[i]<=temp)
- temp=nn[i];
- }
- // 添加点
- if(temp==s1){
- list1.add(point);
- upDataPoint(list1,con1_x,con1_y);
- }
- if(temp==s2){
- list2.add(point);
- upDataPoint(list2,con2_x,con2_x);
- }
- if(temp==s3){
- list3.add(point);
- upDataPoint(list3,con3_x,con3_x);
- }
- if(temp==s4){
- list4.add(point);
- upDataPoint(list4,con4_x,con4_x);
- }
- if(temp==s5){
- list5.add(point);
- upDataPoint(list5,con5_x,con5_x);
- }
- }
- /**
- * 更新中心点坐标
- * @param list
- */
- private void upDataPoint(ArrayList<pointBean> list,double x,double y) {
- double up_x=0;
- double up_y=0;
- for(int i=0;i<list.size();i++){
- up_x+=list.get(i).point_x;
- up_y+=list.get(i).point_y;
- }
- x=up_x/(list.size());
- y=up_y/(list.size());
- }
- }
得到的测试结果:
- -------1-------
- 1的中心点:37.0 80.0
- 37 80
- 54 88
- 10 50
- 45 85
- 40 95
- 51 87
- 47 90
- 42 97
- 30 61
- 20 63
- 60 80
- 37 93
- 47 79
- 37 96
- 58 86
- -------2-------
- 2的中心点:55.0 57.0
- 55 57
- 89 81
- 56 58
- 49 53
- 58 62
- 42 52
- 26 49
- 94 95
- 21 44
- 1 19
- 27 53
- 59 74
- 61 77
- 32 56
- 49 54
- 10 39
- 53 55
- 48 58
- 8 36
- 63 63
- 4 26
- 49 62
- 63 80
- 45 62
- -------3-------
- 3的中心点:79.0 22.0
- 79 22
- 84 41
- 68 17
- 79 2
- 99 33
- 69 11
- 70 29
- 52 8
- 94 25
- 81 8
- 54 20
- 81 32
- 81 34
- 48 2
- 22 1
- 89 27
- 57 18
- 42 11
- 50 6
- 74 28
- 98 27
- 98 36
- -------4-------
- 4的中心点:59.0 51.0
- 59 51
- 64 40
- 21 8
- 5 19
- 29 32
- 62 40
- 7 5
- 16 25
- 53 36
- 28 29
- 33 19
- 80 55
- 50 40
- 98 76
- 81 53
- 23 23
- 92 62
- 85 63
- 65 36
- 48 44
- 25 30
- 11 15
- 97 79
- 16 9
- 60 43
- 59 51
- 67 43
- -------5-------
- 5的中心点:10.0 97.0
- 10 97
- 14 74
- 13 89
- 1 60
- 4 94
- 6 72
- 1 73
- 4 86
- 18 80
- 2 81
- 11 70
- 19 97