聚类算法之kmeans算法java版本

聚类的意思很明确,物以类聚,把类似的事物放在一起。
      聚类算法是web智能中很重要的一步,可运用在社交,新闻,电商等各种应用中,我打算专门开个分类讲解聚类各种算法的java版实现。
     首先介绍kmeans算法。
     kmeans算法的速度很快,性能良好,几乎是应用最广泛的,它需要先指定聚类的个数k,然后根据k值来自动分出k个类别集合。
     举个例子,某某教练在得到全队的数据后,想把这些球员自动分成不同的组别,你得问教练需要分成几个组,他回答你k个,ok可以开始了,在解决这个问题之前有必要详细了解自己需要达到的目的:根据教练给出的k值,呈现出k个组,每个组的队员是相似的。
     首先,我们创建球员类。 

 

01package kmeans;
02  
03   /**
04    * 球员
05     
06    * @author 阿飞哥
07    
08    */
09  public class Player {
10  
11 private int id;
12 private String name;
13  
14 private int age;
15  
16 /* 得分 */
17 @KmeanField
18 private double goal;
19  
20 /* 助攻 */
21 //@KmeanField
22 private double assists;
23  
24 /* 篮板 */
25 //@KmeanField
26 private double backboard;
27  
28 /* 抢断 */
29 //@KmeanField
30 private double steals;
31  
32 public int getId() {
33  return id;
34 }
35  
36 public void setId(int id) {
37  this.id = id;
38 }
39  
40 public String getName() {
41  return name;
42 }
43  
44 public void setName(String name) {
45  this.name = name;
46 }
47  
48 public int getAge() {
49  return age;
50 }
51  
52 public void setAge(int age) {
53  this.age = age;
54 }
55  
56 public double getGoal() {
57  return goal;
58 }
59  
60 public void setGoal(double goal) {
61  this.goal = goal;
62 }
63  
64 public double getAssists() {
65  return assists;
66 }
67  
68 public void setAssists(double assists) {
69  this.assists = assists;
70 }
71  
72 public double getBackboard() {
73  return backboard;
74 }
75  
76 public void setBackboard(double backboard) {
77  this.backboard = backboard;
78 }
79  
80 public double getSteals() {
81  return steals;
82 }
83  
84 public void setSteals(double steals) {
85  this.steals = steals;
86 }
87  
88   
89}

        

   @KmeanField这个注解是自定义的,用来标示这个属性是否是算法需要的维度。
代码如下 

01package kmeans;
02  
03import java.lang.annotation.ElementType;
04import java.lang.annotation.Retention;
05import java.lang.annotation.RetentionPolicy;
06import java.lang.annotation.Target;
07  
08/**
09 * 在对象的属性上标注此注释,
10 * 表示纳入kmeans算法,仅支持数值类属性
11 * @author 阿飞哥
12 */
13@Retention(RetentionPolicy.RUNTIME)
14@Target(ElementType.FIELD)
15public @interface KmeanField {
16}

接下来看看最核心的kmeans算法,具体实现过程如下:
1,初始化k个聚类中心
2,计算出每个对象跟这k个中心的距离(相似度计算,这个下面会提到),假如x这个对象跟y这个中心的距离最小(相似度最大),那么x属于y这个中心。这一步就可以得到初步的k个聚类
3,在第二步得到的每个聚类分别计算出新的聚类中心,和旧的中心比对,假如不相同,则继续第2步,直到新旧两个中心相同,说明聚类不可变,已经成功

实现代码如下:

001package kmeans;
002  
003import java.lang.annotation.Annotation;
004import java.lang.reflect.Field;
005import java.lang.reflect.Method;
006import java.util.ArrayList;
007import java.util.List;
008  
009/**
010 
011 * @author 阿飞哥
012 
013 */
014public class Kmeans<T> {
015  
016 /**
017  * 所有数据列表
018  */
019 private List<T> players = new ArrayList<T>();
020  
021 /**
022  * 数据类别
023  */
024 private Class<T> classT;
025  
026 /**
027  * 初始化列表
028  */
029 private List<T> initPlayers;
030  
031 /**
032  * 需要纳入kmeans算法的属性名称
033  */
034 private List<String> fieldNames = new ArrayList<String>();
035  
036 /**
037  * 分类数
038  */
039 private int k = 1;
040  
041 public Kmeans() {
042  
043 }
044  
045 /**
046  * 初始化列表
047  
048  * @param list
049  * @param k
050  */
051 public Kmeans(List<T> list, int k) {
052  this.players = list;
053  this.k = k;
054  T t = list.get(0);
055  this.classT = (Class<T>) t.getClass();
056  Field[] fields = this.classT.getDeclaredFields();
057  for (int i = 0; i < fields.length; i++) {
058   Annotation kmeansAnnotation = fields[i]
059     .getAnnotation(KmeanField.class);
060   if (kmeansAnnotation != null) {
061    fieldNames.add(fields[i].getName());
062   }
063  
064  }
065  
066  initPlayers = new ArrayList<T>();
067  for (int i = 0; i < k; i++) {
068   initPlayers.add(players.get(i));
069  }
070 }
071  
072 public List<T>[] comput() {
073  List<T>[] results = new ArrayList[k];
074  
075  boolean centerchange = true;
076  while (centerchange) {
077   centerchange = false;
078   for (int i = 0; i < k; i++) {
079    results[i] = new ArrayList<T>();
080   }
081   for (int i = 0; i < players.size(); i++) {
082    T p = players.get(i);
083    double[] dists = new double[k];
084    for (int j = 0; j < initPlayers.size(); j++) {
085     T initP = initPlayers.get(j);
086     /* 计算距离 */
087     double dist = distance(initP, p);
088     dists[j] = dist;
089    }
090  
091    int dist_index = computOrder(dists);
092    results[dist_index].add(p);
093   }
094  
095   for (int i = 0; i < k; i++) {
096    T player_new = findNewCenter(results[i]);
097    T player_old = initPlayers.get(i);
098    if (!IsPlayerEqual(player_new, player_old)) {
099     centerchange = true;
100     initPlayers.set(i, player_new);
101    }
102  
103   }
104  
105  }
106  
107  return results;
108 }
109  
110 /**
111  * 比较是否两个对象是否属性一致
112  
113  * @param p1
114  * @param p2
115  * @return
116  */
117 public boolean IsPlayerEqual(T p1, T p2) {
118  if (p1 == p2) {
119   return true;
120  }
121  if (p1 == null || p2 == null) {
122   return false;
123  }
124  
125    
126  
127  boolean flag = true;
128  try {
129   for (int i = 0; i < fieldNames.size(); i++) {
130    String fieldName=fieldNames.get(i);
131    String getName = "get"
132      + fieldName.substring(0, 1).toUpperCase()
133      + fieldName.substring(1);    
134    Object value1 = invokeMethod(p1,getName,null);
135    Object value2 = invokeMethod(p2,getName,null);
136    if (!value1.equals(value2)) {
137     flag = false;
138     break;
139    }
140   }
141  } catch (Exception e) {
142   e.printStackTrace();
143   flag = false;
144  }
145  
146  return flag;
147 }
148  
149 /**
150  * 得到新聚类中心对象
151  
152  * @param ps
153  * @return
154  */
155 public T findNewCenter(List<T> ps) {
156  try {
157   T t = classT.newInstance();
158   if (ps == null || ps.size() == 0) {
159    return t;
160   }
161  
162   double[] ds = new double[fieldNames.size()];
163   for (T vo : ps) {
164    for (int i = 0; i < fieldNames.size(); i++) {
165     String fieldName=fieldNames.get(i);
166     String getName = "get"
167       + fieldName.substring(0, 1).toUpperCase()
168       + fieldName.substring(1);
169     Object obj=invokeMethod(vo,getName,null);
170     Double fv=(obj==null?0:Double.parseDouble(obj+""));
171     ds[i] += fv;
172    }
173  
174   }
175  
176   for (int i = 0; i < fieldNames.size(); i++) {
177    ds[i] = ds[i] / ps.size();
178    String fieldName = fieldNames.get(i);
179      
180    /* 给对象设值 */
181    String setName = "set"
182      + fieldName.substring(0, 1).toUpperCase()
183      + fieldName.substring(1);
184  
185    invokeMethod(t,setName,new Class[]{double.class},ds[i]);
186  
187   }
188  
189   return t;
190  } catch (Exception ex) {
191   ex.printStackTrace();
192  }
193  return null;
194  
195 }
196  
197 /**
198  * 得到最短距离,并返回最短距离索引
199  
200  * @param dists
201  * @return
202  */
203 public int computOrder(double[] dists) {
204  double min = 0;
205  int index = 0;
206  for (int i = 0; i < dists.length - 1; i++) {
207   double dist0 = dists[i];
208   if (i == 0) {
209    min = dist0;
210    index = 0;
211   }
212   double dist1 = dists[i + 1];
213   if (min > dist1) {
214    min = dist1;
215    index = i + 1;
216   }
217  }
218  
219  return index;
220 }
221  
222 /**
223  * 计算距离(相似性) 采用欧几里得算法
224  
225  * @param p0
226  * @param p1
227  * @return
228  */
229 public double distance(T p0, T p1) {
230  double dis = 0;
231  try {
232  
233   for (int i = 0; i < fieldNames.size(); i++) {
234    String fieldName = fieldNames.get(i);
235    String getName = "get"
236      + fieldName.substring(0, 1).toUpperCase()
237      + fieldName.substring(1);
238      
239    Double field0Value=Double.parseDouble(invokeMethod(p0,getName,null)+"");
240    Double field1Value=Double.parseDouble(invokeMethod(p1,getName,null)+"");
241    dis += Math.pow(field0Value - field1Value, 2); 
242   }
243    
244  } catch (Exception ex) {
245   ex.printStackTrace();
246  }
247  return Math.sqrt(dis);
248  
249 }
250   
251 /*------公共方法-----*/
252 public Object invokeMethod(Object owner, String methodName,Class[] argsClass,
253   Object... args) {
254  Class ownerClass = owner.getClass();
255  try {
256   Method method=ownerClass.getDeclaredMethod(methodName,argsClass);
257   return method.invoke(owner, args);
258  } catch (SecurityException e) {
259   e.printStackTrace();
260  } catch (NoSuchMethodException e) {
261   e.printStackTrace();
262  } catch (Exception ex) {
263   ex.printStackTrace();
264  }
265  
266  return null;
267 }
268  
269}

最后咱们测试一下:

01package kmeans;
02  
03import java.util.ArrayList;
04import java.util.List;
05import java.util.Random;
06  
07public class TestMain {
08  
09 public static void main(String[] args) {
10       List<Player> listPlayers=new ArrayList<Player>();
11          
12        for(int i=0;i<15;i++){
13           
14         Player p1=new Player();
15         p1.setName("afei-"+i);
16         p1.setAssists(i);
17         p1.setBackboard(i);
18           
19         //p1.setGoal(new Random(100*i).nextDouble());
20         p1.setGoal(i*10);
21         p1.setSteals(i);
22         //listPlayers.add(p1); 
23        }
24          
25        Player p1=new Player();
26        p1.setName("afei1");
27        p1.setGoal(1);
28        p1.setAssists(8);
29        listPlayers.add(p1);
30         
31        Player p2=new Player();
32        p2.setName("afei2");
33        p2.setGoal(2);
34        listPlayers.add(p2);
35          
36         Player p3=new Player();
37        p3.setName("afei3");
38        p3.setGoal(3);
39        listPlayers.add(p3);
40          
41         Player p4=new Player();
42        p4.setName("afei4");
43        p4.setGoal(7);
44        listPlayers.add(p4);
45          
46         Player p5=new Player();
47        p5.setName("afei5");
48        p5.setGoal(8);
49        listPlayers.add(p5);
50          
51         Player p6=new Player();
52        p6.setName("afei6");
53        p6.setGoal(25);
54        listPlayers.add(p6);
55          
56         Player p7=new Player();
57        p7.setName("afei7");
58        p7.setGoal(26);
59        listPlayers.add(p7);
60          
61         Player p8=new Player();
62        p8.setName("afei8");
63        p8.setGoal(27);
64        listPlayers.add(p8);
65          
66         Player p9=new Player();
67        p9.setName("afei9");
68        p9.setGoal(28);
69        listPlayers.add(p9);
70          
71          
72  Kmeans<Player> kmeans = new Kmeans<Player>(listPlayers,3);
73  List<Player>[] results = kmeans.comput();
74  for (int i = 0; i < results.length; i++) {
75   System.out.println("===========类别" + (i + 1) + "================");
76   List<Player> list = results[i];
77   for (Player p : list) {
78    System.out.println(p.getName() + "--->"
79      + p.getGoal() + "," + p.getAssists() + ","
80      + p.getSteals() + "," + p.getBackboard());
81   }
82  }
83    
84    
85    
86        
87 }
88  
89}

结果如下

  这个里面涉及到相似度算法,事实证明欧几里得距离算法的实践效果是最优的。
  最后说说kmeans算法的不足:可以看到只能针对数字类型的属性(维),对于其他类型的除非选定合适的数值度量

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值