聚类算法之kmeans算法java版本



聚类的意思很明确,物以类聚,把类似的事物放在一起。
聚类算法是web智能中很重要的一步,可运用在社交,新闻,电商等各种应用中,我打算专门开个分类讲解聚类各种算法的java版实现。
首先介绍kmeans算法。
kmeans算法的速度很快,性能良好,几乎是应用最广泛的,它需要先指定聚类的个数k,然后根据k值来自动分出k个类别集合。
举个例子,某某教练在得到全队的数据后,想把这些球员自动分成不同的组别,你得问教练需要分成几个组,他回答你k个,ok可以开始了,在解决这个问题之前有必要详细了解自己需要达到的目的:根据教练给出的k值,呈现出k个组,每个组的队员是相似的。
首先,我们创建球员类。

 

01
package kmeans;
02

03
/**
04
* 球员
05
*
06
* @author 阿飞哥
07
*
08
*/
09
public class Player {
10

11
private int id;
12
private String name;
13

14
private int age;
15

16
/* 得分 */
17
@KmeanField
18
private double goal;
19

20
/* 助攻 */
21
//@KmeanField
22
private double assists;
23

24
/* 篮板 */
25
//@KmeanField
26
private double backboard;
27

28
/* 抢断 */
29
//@KmeanField
30
private double steals;
31

32
public int getId() {
33
return id;
34
}
35

36
public void setId(int id) {
37
this.id = id;
38
}
39

40
public String getName() {
41
return name;
42
}
43

44
public void setName(String name) {
45
this.name = name;
46
}
47

48
public int getAge() {
49
return age;
50
}
51

52
public void setAge(int age) {
53
this.age = age;
54
}
55

56
public double getGoal() {
57
return goal;
58
}
59

60
public void setGoal(double goal) {
61
this.goal = goal;
62
}
63

64
public double getAssists() {
65
return assists;
66
}
67

68
public void setAssists(double assists) {
69
this.assists = assists;
70
}
71

72
public double getBackboard() {
73
return backboard;
74
}
75

76
public void setBackboard(double backboard) {
77
this.backboard = backboard;
78
}
79

80
public double getSteals() {
81
return steals;
82
}
83

84
public void setSteals(double steals) {
85
this.steals = steals;
86
}
87

88

89
}

@KmeanField这个注解是自定义的,用来标示这个属性是否是算法需要的维度。
代码如下
01
package kmeans;
02

03
import java.lang.annotation.ElementType;
04
import java.lang.annotation.Retention;
05
import java.lang.annotation.RetentionPolicy;
06
import java.lang.annotation.Target;
07

08
/**
09
* 在对象的属性上标注此注释,
10
* 表示纳入kmeans算法,仅支持数值类属性
11
* @author 阿飞哥
12
*/
13
@Retention(RetentionPolicy.RUNTIME)
14
@Target(ElementType.FIELD)
15
public @interface KmeanField {
16
}
接下来看看最核心的kmeans算法,具体实现过程如下:
1,初始化k个聚类中心
2,计算出每个对象跟这k个中心的距离(相似度计算,这个下面会提到),假如x这个对象跟y这个中心的距离最小(相似度最大),那么x属于y这个中心。这一步就可以得到初步的k个聚类
3,在第二步得到的每个聚类分别计算出新的聚类中心,和旧的中心比对,假如不相同,则继续第2步,直到新旧两个中心相同,说明聚类不可变,已经成功

实现代码如下:
001
package kmeans;
002

003
import java.lang.annotation.Annotation;
004
import java.lang.reflect.Field;
005
import java.lang.reflect.Method;
006
import java.util.ArrayList;
007
import java.util.List;
008

009
/**
010
*
011
* @author 阿飞哥
012
*
013
*/
014
public class Kmeans<T> {
015

016
/**
017
* 所有数据列表
018
*/
019
private List<T> players = new ArrayList<T>();
020

021
/**
022
* 数据类别
023
*/
024
private Class<T> classT;
025

026
/**
027
* 初始化列表
028
*/
029
private List<T> initPlayers;
030

031
/**
032
* 需要纳入kmeans算法的属性名称
033
*/
034
private List<String> fieldNames = new ArrayList<String>();
035

036
/**
037
* 分类数
038
*/
039
private int k = 1;
040

041
public Kmeans() {
042

043
}
044

045
/**
046
* 初始化列表
047
*
048
* @param list
049
* @param k
050
*/
051
public Kmeans(List<T> list, int k) {
052
this.players = list;
053
this.k = k;
054
T t = list.get(0);
055
this.classT = (Class<T>) t.getClass();
056
Field[] fields = this.classT.getDeclaredFields();
057
for (int i = 0; i < fields.length; i++) {
058
Annotation kmeansAnnotation = fields[i]
059
.getAnnotation(KmeanField.class);
060
if (kmeansAnnotation != null) {
061
fieldNames.add(fields[i].getName());
062
}
063

064
}
065

066
initPlayers = new ArrayList<T>();
067
for (int i = 0; i < k; i++) {
068
initPlayers.add(players.get(i));
069
}
070
}
071

072
public List<T>[] comput() {
073
List<T>[] results = new ArrayList[k];
074

075
boolean centerchange = true;
076
while (centerchange) {
077
centerchange = false;
078
for (int i = 0; i < k; i++) {
079
results[i] = new ArrayList<T>();
080
}
081
for (int i = 0; i < players.size(); i++) {
082
T p = players.get(i);
083
double[] dists = new double[k];
084
for (int j = 0; j < initPlayers.size(); j++) {
085
T initP = initPlayers.get(j);
086
/* 计算距离 */
087
double dist = distance(initP, p);
088
dists[j] = dist;
089
}
090

091
int dist_index = computOrder(dists);
092
results[dist_index].add(p);
093
}
094

095
for (int i = 0; i < k; i++) {
096
T player_new = findNewCenter(results[i]);
097
T player_old = initPlayers.get(i);
098
if (!IsPlayerEqual(player_new, player_old)) {
099
centerchange = true;
100
initPlayers.set(i, player_new);
101
}
102

103
}
104

105
}
106

107
return results;
108
}
109

110
/**
111
* 比较是否两个对象是否属性一致
112
*
113
* @param p1
114
* @param p2
115
* @return
116
*/
117
public boolean IsPlayerEqual(T p1, T p2) {
118
if (p1 == p2) {
119
return true;
120
}
121
if (p1 == null || p2 == null) {
122
return false;
123
}
124

125

126

127
boolean flag = true;
128
try {
129
for (int i = 0; i < fieldNames.size(); i++) {
130
String fieldName=fieldNames.get(i);
131
String getName = "get"
132
+ fieldName.substring(0, 1).toUpperCase()
133
+ fieldName.substring(1);
134
Object value1 = invokeMethod(p1,getName,null);
135
Object value2 = invokeMethod(p2,getName,null);
136
if (!value1.equals(value2)) {
137
flag = false;
138
break;
139
}
140
}
141
} catch (Exception e) {
142
e.printStackTrace();
143
flag = false;
144
}
145

146
return flag;
147
}
148

149
/**
150
* 得到新聚类中心对象
151
*
152
* @param ps
153
* @return
154
*/
155
public T findNewCenter(List<T> ps) {
156
try {
157
T t = classT.newInstance();
158
if (ps == null || ps.size() == 0) {
159
return t;
160
}
161

162
double[] ds = new double[fieldNames.size()];
163
for (T vo : ps) {
164
for (int i = 0; i < fieldNames.size(); i++) {
165
String fieldName=fieldNames.get(i);
166
String getName = "get"
167
+ fieldName.substring(0, 1).toUpperCase()
168
+ fieldName.substring(1);
169
Object obj=invokeMethod(vo,getName,null);
170
Double fv=(obj==null?0:Double.parseDouble(obj+""));
171
ds[i] += fv;
172
}
173

174
}
175

176
for (int i = 0; i < fieldNames.size(); i++) {
177
ds[i] = ds[i] / ps.size();
178
String fieldName = fieldNames.get(i);
179

180
/* 给对象设值 */
181
String setName = "set"
182
+ fieldName.substring(0, 1).toUpperCase()
183
+ fieldName.substring(1);
184

185
invokeMethod(t,setName,new Class[]{double.class},ds[i]);
186

187
}
188

189
return t;
190
} catch (Exception ex) {
191
ex.printStackTrace();
192
}
193
return null;
194

195
}
196

197
/**
198
* 得到最短距离,并返回最短距离索引
199
*
200
* @param dists
201
* @return
202
*/
203
public int computOrder(double[] dists) {
204
double min = 0;
205
int index = 0;
206
for (int i = 0; i < dists.length - 1; i++) {
207
double dist0 = dists[i];
208
if (i == 0) {
209
min = dist0;
210
index = 0;
211
}
212
double dist1 = dists[i + 1];
213
if (min > dist1) {
214
min = dist1;
215
index = i + 1;
216
}
217
}
218

219
return index;
220
}
221

222
/**
223
* 计算距离(相似性) 采用欧几里得算法
224
*
225
* @param p0
226
* @param p1
227
* @return
228
*/
229
public double distance(T p0, T p1) {
230
double dis = 0;
231
try {
232

233
for (int i = 0; i < fieldNames.size(); i++) {
234
String fieldName = fieldNames.get(i);
235
String getName = "get"
236
+ fieldName.substring(0, 1).toUpperCase()
237
+ fieldName.substring(1);
238

239
Double field0Value=Double.parseDouble(invokeMethod(p0,getName,null)+"");
240
Double field1Value=Double.parseDouble(invokeMethod(p1,getName,null)+"");
241
dis += Math.pow(field0Value - field1Value, 2);
242
}
243

244
} catch (Exception ex) {
245
ex.printStackTrace();
246
}
247
return Math.sqrt(dis);
248

249
}
250

251
/*------公共方法-----*/
252
public Object invokeMethod(Object owner, String methodName,Class[] argsClass,
253
Object... args) {
254
Class ownerClass = owner.getClass();
255
try {
256
Method method=ownerClass.getDeclaredMethod(methodName,argsClass);
257
return method.invoke(owner, args);
258
} catch (SecurityException e) {
259
e.printStackTrace();
260
} catch (NoSuchMethodException e) {
261
e.printStackTrace();
262
} catch (Exception ex) {
263
ex.printStackTrace();
264
}
265

266
return null;
267
}
268

269
}
最后咱们测试一下:
01
package kmeans;
02

03
import java.util.ArrayList;
04
import java.util.List;
05
import java.util.Random;
06

07
public class TestMain {
08

09
public static void main(String[] args) {
10
List<Player> listPlayers=new ArrayList<Player>();
11

12
for(int i=0;i<15;i++){
13

14
Player p1=new Player();
15
p1.setName("afei-"+i);
16
p1.setAssists(i);
17
p1.setBackboard(i);
18

19
//p1.setGoal(new Random(100*i).nextDouble());
20
p1.setGoal(i*10);
21
p1.setSteals(i);
22
//listPlayers.add(p1);
23
}
24

25
Player p1=new Player();
26
p1.setName("afei1");
27
p1.setGoal(1);
28
p1.setAssists(8);
29
listPlayers.add(p1);
30

31
Player p2=new Player();
32
p2.setName("afei2");
33
p2.setGoal(2);
34
listPlayers.add(p2);
35

36
Player p3=new Player();
37
p3.setName("afei3");
38
p3.setGoal(3);
39
listPlayers.add(p3);
40

41
Player p4=new Player();
42
p4.setName("afei4");
43
p4.setGoal(7);
44
listPlayers.add(p4);
45

46
Player p5=new Player();
47
p5.setName("afei5");
48
p5.setGoal(8);
49
listPlayers.add(p5);
50

51
Player p6=new Player();
52
p6.setName("afei6");
53
p6.setGoal(25);
54
listPlayers.add(p6);
55

56
Player p7=new Player();
57
p7.setName("afei7");
58
p7.setGoal(26);
59
listPlayers.add(p7);
60

61
Player p8=new Player();
62
p8.setName("afei8");
63
p8.setGoal(27);
64
listPlayers.add(p8);
65

66
Player p9=new Player();
67
p9.setName("afei9");
68
p9.setGoal(28);
69
listPlayers.add(p9);
70

71

72
Kmeans<Player> kmeans = new Kmeans<Player>(listPlayers,3);
73
List<Player>[] results = kmeans.comput();
74
for (int i = 0; i < results.length; i++) {
75
System.out.println("===========类别" + (i + 1) + "================");
76
List<Player> list = results[i];
77
for (Player p : list) {
78
System.out.println(p.getName() + "--->"
79
+ p.getGoal() + "," + p.getAssists() + ","
80
+ p.getSteals() + "," + p.getBackboard());
81
}
82
}
83

84

85

86

87
}
88

89
}
结果如下

 

这个里面涉及到相似度算法,事实证明欧几里得距离算法的实践效果是最优的。
最后说说kmeans算法的不足:可以看到只能针对数字类型的属性(维),对于其他类型的除非选定合适的数值度量。

转载于:https://www.cnblogs.com/gaiwen/articles/2955999.html

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值