聚类的意思很明确,物以类聚,把类似的事物放在一起。
聚类算法是Web智能中很重要的一环，可应用于社交、新闻、电商等各类应用中。我打算专门开一个分类，讲解各种聚类算法的Java实现。
首先介绍kmeans算法。
kmeans算法速度快、效果好，几乎是应用最广泛的聚类算法。它需要先指定聚类的个数k，然后根据k值自动划分出k个类别集合。
举个例子：某教练拿到全队的数据后，想把球员自动分成不同的组。你得先问教练需要分成几组，他回答k组，好，可以开始了。在解决这个问题之前，有必要先明确要达到的目的：根据教练给出的k值划分出k个组，并且每个组内的队员彼此相似。
首先,我们创建球员类。
01 | package kmeans; |
02 | |
03 | /** |
04 | * 球员 |
05 | * |
06 | * @author 阿飞哥 |
07 | * |
08 | */ |
09 | public class Player { |
10 | |
11 | private int id; |
12 | private String name; |
13 | |
14 | private int age; |
15 | |
16 | /* 得分 */ |
17 | @KmeanField |
18 | private double goal; |
19 | |
20 | /* 助攻 */ |
21 | //@KmeanField |
22 | private double assists; |
23 | |
24 | /* 篮板 */ |
25 | //@KmeanField |
26 | private double backboard; |
27 | |
28 | /* 抢断 */ |
29 | //@KmeanField |
30 | private double steals; |
31 | |
32 | public int getId() { |
33 | return id; |
34 | } |
35 | |
36 | public void setId( int id) { |
37 | this .id = id; |
38 | } |
39 | |
40 | public String getName() { |
41 | return name; |
42 | } |
43 | |
44 | public void setName(String name) { |
45 | this .name = name; |
46 | } |
47 | |
48 | public int getAge() { |
49 | return age; |
50 | } |
51 | |
52 | public void setAge( int age) { |
53 | this .age = age; |
54 | } |
55 | |
56 | public double getGoal() { |
57 | return goal; |
58 | } |
59 | |
60 | public void setGoal( double goal) { |
61 | this .goal = goal; |
62 | } |
63 | |
64 | public double getAssists() { |
65 | return assists; |
66 | } |
67 | |
68 | public void setAssists( double assists) { |
69 | this .assists = assists; |
70 | } |
71 | |
72 | public double getBackboard() { |
73 | return backboard; |
74 | } |
75 | |
76 | public void setBackboard( double backboard) { |
77 | this .backboard = backboard; |
78 | } |
79 | |
80 | public double getSteals() { |
81 | return steals; |
82 | } |
83 | |
84 | public void setSteals( double steals) { |
85 | this .steals = steals; |
86 | } |
87 | |
88 | |
89 | } |
@KmeanField是一个自定义注解，用来标记某个属性是否作为算法所需的维度参与计算。
代码如下
01 | package kmeans; |
02 | |
03 | import java.lang.annotation.ElementType; |
04 | import java.lang.annotation.Retention; |
05 | import java.lang.annotation.RetentionPolicy; |
06 | import java.lang.annotation.Target; |
07 | |
/**
 * Marker placed on a field to make it one dimension of the k-means feature vector.
 *
 * <p>Only numeric attributes are supported. The annotation is retained at runtime so the
 * clustering code can discover annotated fields through reflection.
 *
 * @author 阿飞哥
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface KmeanField {
    // Pure marker: declares no elements.
}
接下来看看最核心的kmeans算法,具体实现过程如下:
1,初始化k个聚类中心
2,计算出每个对象跟这k个中心的距离(相似度计算,这个下面会提到),假如x这个对象跟y这个中心的距离最小(相似度最大),那么x属于y这个中心。这一步就可以得到初步的k个聚类
3，对第2步得到的每个聚类分别计算新的聚类中心（即该聚类中所有对象在各维度上的平均值），并与旧的中心比较；若不相同，则回到第2步继续迭代，直到新旧中心完全相同，说明聚类结果不再变化，算法已经收敛。
实现代码如下:
001 | package kmeans; |
002 | |
003 | import java.lang.annotation.Annotation; |
004 | import java.lang.reflect.Field; |
005 | import java.lang.reflect.Method; |
006 | import java.util.ArrayList; |
007 | import java.util.List; |
008 | |
009 | /** |
010 | * |
011 | * @author 阿飞哥 |
012 | * |
013 | */ |
014 | public class Kmeans<T> { |
015 | |
016 | /** |
017 | * 所有数据列表 |
018 | */ |
019 | private List<T> players = new ArrayList<T>(); |
020 | |
021 | /** |
022 | * 数据类别 |
023 | */ |
024 | private Class<T> classT; |
025 | |
026 | /** |
027 | * 初始化列表 |
028 | */ |
029 | private List<T> initPlayers; |
030 | |
031 | /** |
032 | * 需要纳入kmeans算法的属性名称 |
033 | */ |
034 | private List<String> fieldNames = new ArrayList<String>(); |
035 | |
036 | /** |
037 | * 分类数 |
038 | */ |
039 | private int k = 1 ; |
040 | |
041 | public Kmeans() { |
042 | |
043 | } |
044 | |
045 | /** |
046 | * 初始化列表 |
047 | * |
048 | * @param list |
049 | * @param k |
050 | */ |
051 | public Kmeans(List<T> list, int k) { |
052 | this .players = list; |
053 | this .k = k; |
054 | T t = list.get( 0 ); |
055 | this .classT = (Class<T>) t.getClass(); |
056 | Field[] fields = this .classT.getDeclaredFields(); |
057 | for ( int i = 0 ; i < fields.length; i++) { |
058 | Annotation kmeansAnnotation = fields[i] |
059 | .getAnnotation(KmeanField. class ); |
060 | if (kmeansAnnotation != null ) { |
061 | fieldNames.add(fields[i].getName()); |
062 | } |
063 | |
064 | } |
065 | |
066 | initPlayers = new ArrayList<T>(); |
067 | for ( int i = 0 ; i < k; i++) { |
068 | initPlayers.add(players.get(i)); |
069 | } |
070 | } |
071 | |
072 | public List<T>[] comput() { |
073 | List<T>[] results = new ArrayList[k]; |
074 | |
075 | boolean centerchange = true ; |
076 | while (centerchange) { |
077 | centerchange = false ; |
078 | for ( int i = 0 ; i < k; i++) { |
079 | results[i] = new ArrayList<T>(); |
080 | } |
081 | for ( int i = 0 ; i < players.size(); i++) { |
082 | T p = players.get(i); |
083 | double [] dists = new double [k]; |
084 | for ( int j = 0 ; j < initPlayers.size(); j++) { |
085 | T initP = initPlayers.get(j); |
086 | /* 计算距离 */ |
087 | double dist = distance(initP, p); |
088 | dists[j] = dist; |
089 | } |
090 | |
091 | int dist_index = computOrder(dists); |
092 | results[dist_index].add(p); |
093 | } |
094 | |
095 | for (int i = 0; i < k; i++) { |
096 | T player_new = findNewCenter(results[i]); |
097 | T player_old = initPlayers.get(i); |
098 | if (!IsPlayerEqual(player_new, player_old)) { |
099 | centerchange = true; |
100 | initPlayers.set(i, player_new); |
101 | } |
102 | |
103 | } |
104 | |
105 | } |
106 | |
107 | return results; |
108 | } |
109 | |
110 | /** |
111 | * 比较是否两个对象是否属性一致 |
112 | * |
113 | * @param p1 |
114 | * @param p2 |
115 | * @return |
116 | */ |
117 | public boolean IsPlayerEqual(T p1, T p2) { |
118 | if (p1 == p2) { |
119 | return true; |
120 | } |
121 | if (p1 == null || p2 == null) { |
122 | return false; |
123 | } |
124 | |
125 | |
126 | |
127 | boolean flag = true; |
128 | try { |
129 | for (int i = 0; i < fieldNames.size(); i++) { |
130 | String fieldName=fieldNames.get(i); |
131 | String getName = "get" |
132 | + fieldName.substring(0, 1).toUpperCase() |
133 | + fieldName.substring(1); |
134 | Object value1 = invokeMethod(p1,getName,null); |
135 | Object value2 = invokeMethod(p2,getName,null); |
136 | if (!value1.equals(value2)) { |
137 | flag = false; |
138 | break; |
139 | } |
140 | } |
141 | } catch (Exception e) { |
142 | e.printStackTrace(); |
143 | flag = false; |
144 | } |
145 | |
146 | return flag; |
147 | } |
148 | |
149 | /** |
150 | * 得到新聚类中心对象 |
151 | * |
152 | * @param ps |
153 | * @return |
154 | */ |
155 | public T findNewCenter(List<T> ps) { |
156 | try { |
157 | T t = classT.newInstance(); |
158 | if (ps == null || ps.size() == 0) { |
159 | return t; |
160 | } |
161 | |
162 | double[] ds = new double[fieldNames.size()]; |
163 | for (T vo : ps) { |
164 | for (int i = 0; i < fieldNames.size(); i++) { |
165 | String fieldName=fieldNames.get(i); |
166 | String getName = "get" |
167 | + fieldName.substring(0, 1).toUpperCase() |
168 | + fieldName.substring(1); |
169 | Object obj=invokeMethod(vo,getName,null); |
170 | Double fv=(obj==null?0:Double.parseDouble(obj+"")); |
171 | ds[i] += fv; |
172 | } |
173 | |
174 | } |
175 | |
176 | for (int i = 0; i < fieldNames.size(); i++) { |
177 | ds[i] = ds[i] / ps.size(); |
178 | String fieldName = fieldNames.get(i); |
179 | |
180 | /* 给对象设值 */ |
181 | String setName = "set" |
182 | + fieldName.substring(0, 1).toUpperCase() |
183 | + fieldName.substring(1); |
184 | |
185 | invokeMethod(t,setName,new Class[]{double.class},ds[i]); |
186 | |
187 | } |
188 | |
189 | return t; |
190 | } catch (Exception ex) { |
191 | ex.printStackTrace(); |
192 | } |
193 | return null; |
194 | |
195 | } |
196 | |
197 | /** |
198 | * 得到最短距离,并返回最短距离索引 |
199 | * |
200 | * @param dists |
201 | * @return |
202 | */ |
203 | public int computOrder(double[] dists) { |
204 | double min = 0; |
205 | int index = 0; |
206 | for (int i = 0; i < dists.length - 1; i++) { |
207 | double dist0 = dists[i]; |
208 | if (i == 0) { |
209 | min = dist0; |
210 | index = 0; |
211 | } |
212 | double dist1 = dists[i + 1]; |
213 | if (min > dist1) { |
214 | min = dist1; |
215 | index = i + 1; |
216 | } |
217 | } |
218 | |
219 | return index; |
220 | } |
221 | |
222 | /** |
223 | * 计算距离(相似性) 采用欧几里得算法 |
224 | * |
225 | * @param p0 |
226 | * @param p1 |
227 | * @return |
228 | */ |
229 | public double distance(T p0, T p1) { |
230 | double dis = 0; |
231 | try { |
232 | |
233 | for (int i = 0; i < fieldNames.size(); i++) { |
234 | String fieldName = fieldNames.get(i); |
235 | String getName = "get" |
236 | + fieldName.substring(0, 1).toUpperCase() |
237 | + fieldName.substring(1); |
238 | |
239 | Double field0Value=Double.parseDouble(invokeMethod(p0,getName,null)+""); |
240 | Double field1Value=Double.parseDouble(invokeMethod(p1,getName,null)+""); |
241 | dis += Math.pow(field0Value - field1Value, 2); |
242 | } |
243 | |
244 | } catch (Exception ex) { |
245 | ex.printStackTrace(); |
246 | } |
247 | return Math.sqrt(dis); |
248 | |
249 | } |
250 | |
251 | /*------公共方法-----*/ |
252 | public Object invokeMethod(Object owner, String methodName,Class[] argsClass, |
253 | Object... args) { |
254 | Class ownerClass = owner.getClass(); |
255 | try { |
256 | Method method=ownerClass.getDeclaredMethod(methodName,argsClass); |
257 | return method.invoke(owner, args); |
258 | } catch (SecurityException e) { |
259 | e.printStackTrace(); |
260 | } catch (NoSuchMethodException e) { |
261 | e.printStackTrace(); |
262 | } catch (Exception ex) { |
263 | ex.printStackTrace(); |
264 | } |
265 | |
266 | return null ; |
267 | } |
268 | |
269 | } |
最后咱们测试一下:
01 | package kmeans; |
02 | |
03 | import java.util.ArrayList; |
04 | import java.util.List; |
05 | import java.util.Random; |
06 | |
07 | public class TestMain { |
08 | |
09 | public static void main(String[] args) { |
10 | List<Player> listPlayers= new ArrayList<Player>(); |
11 | |
12 | for ( int i= 0 ;i< 15 ;i++){ |
13 | |
14 | Player p1= new Player(); |
15 | p1.setName( "afei-" +i); |
16 | p1.setAssists(i); |
17 | p1.setBackboard(i); |
18 | |
19 | //p1.setGoal(new Random(100*i).nextDouble()); |
20 | p1.setGoal(i* 10 ); |
21 | p1.setSteals(i); |
22 | //listPlayers.add(p1); |
23 | } |
24 | |
25 | Player p1= new Player(); |
26 | p1.setName( "afei1" ); |
27 | p1.setGoal( 1 ); |
28 | p1.setAssists( 8 ); |
29 | listPlayers.add(p1); |
30 | |
31 | Player p2= new Player(); |
32 | p2.setName( "afei2" ); |
33 | p2.setGoal( 2 ); |
34 | listPlayers.add(p2); |
35 | |
36 | Player p3= new Player(); |
37 | p3.setName( "afei3" ); |
38 | p3.setGoal( 3 ); |
39 | listPlayers.add(p3); |
40 | |
41 | Player p4= new Player(); |
42 | p4.setName( "afei4" ); |
43 | p4.setGoal( 7 ); |
44 | listPlayers.add(p4); |
45 | |
46 | Player p5= new Player(); |
47 | p5.setName( "afei5" ); |
48 | p5.setGoal( 8 ); |
49 | listPlayers.add(p5); |
50 | |
51 | Player p6= new Player(); |
52 | p6.setName( "afei6" ); |
53 | p6.setGoal( 25 ); |
54 | listPlayers.add(p6); |
55 | |
56 | Player p7= new Player(); |
57 | p7.setName( "afei7" ); |
58 | p7.setGoal( 26 ); |
59 | listPlayers.add(p7); |
60 | |
61 | Player p8= new Player(); |
62 | p8.setName( "afei8" ); |
63 | p8.setGoal( 27 ); |
64 | listPlayers.add(p8); |
65 | |
66 | Player p9= new Player(); |
67 | p9.setName( "afei9" ); |
68 | p9.setGoal( 28 ); |
69 | listPlayers.add(p9); |
70 | |
71 | |
72 | Kmeans<Player> kmeans = new Kmeans<Player>(listPlayers, 3 ); |
73 | List<Player>[] results = kmeans.comput(); |
74 | for ( int i = 0 ; i < results.length; i++) { |
75 | System.out.println( "===========类别" + (i + 1 ) + "================" ); |
76 | List<Player> list = results[i]; |
77 | for (Player p : list) { |
78 | System.out.println(p.getName() + "--->" |
79 | + p.getGoal() + "," + p.getAssists() + "," |
80 | + p.getSteals() + "," + p.getBackboard()); |
81 | } |
82 | } |
83 | |
84 | |
85 | |
86 | |
87 | } |
88 | |
89 | } |
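结果如下（按上面的测试代码运行，k=3，输出格式为“名字--->得分,助攻,抢断,篮板”）：
===========类别1================
afei1--->1.0,8.0,0.0,0.0
afei2--->2.0,0.0,0.0,0.0
afei3--->3.0,0.0,0.0,0.0
===========类别2================
afei4--->7.0,0.0,0.0,0.0
afei5--->8.0,0.0,0.0,0.0
===========类别3================
afei6--->25.0,0.0,0.0,0.0
afei7--->26.0,0.0,0.0,0.0
afei8--->27.0,0.0,0.0,0.0
afei9--->28.0,0.0,0.0,0.0
可以看到，得分相近的球员被分到了同一组。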
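这里涉及相似度的计算，本例采用欧几里得距离。kmeans以簇内均值作为中心，这恰好对应于最小化各对象到其中心的欧氏距离平方和，因此欧氏距离是kmeans最常用、最自然的度量。但它并非在所有场景下都最优：例如文本等高维稀疏数据常用余弦相似度；各维度量纲差异较大时，也应先做归一化再计算距离。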
最后说说kmeans算法的不足：可以看到，它只能处理数值类型的属性（维度）；对于其他类型的属性，除非先选定合适的数值度量将其转换为数值，否则无法直接使用。此外，k值需要预先指定，聚类结果也依赖初始中心的选取（本文直接取前k个对象作为初始中心）。