K-Means算法是最古老也是应用最广泛的聚类算法,它使用质心定义原型,质心是一组点的均值,通常该算法用于n维连续空间中的对象。
K-Means算法流程
step1:选择K个点作为初始质心
step2:repeat
将每个点指派到最近的质心,形成K个簇
重新计算每个簇的质心
until 质心不在变化
![](http://static.oschina.net/uploads/img/201306/28152035_sYVx.png)
我们对每一个步骤都进行分析
step1:选择K个点作为初始质心
这一步首先要知道K的值,也就是说K是手动设置的,而不是像EM算法那样自动聚类成n个簇
其次,如何选择初始质心
最简单的方式无异于,随机选取质心了,然后多次运行,取效果最好的那个结果。这个方法,简单但不见得有效,有很大的可能是得到局部最优。
另一种复杂的方式是,随机选取一个质心,然后计算离这个质心最远的样本点,对于每个后继质心都选取已经选取过的质心的最远点。使用这种方式,可以确保质心是随机的,并且是散开的。
step2:repeat
将每个点指派到最近的质心,形成K个簇
重新计算每个簇的质心
until 质心不在变化
如何定义最近的概念,对于欧式空间中的点,可以使用欧式空间,对于文档可以用余弦相似性等等。对于给定的数据,可能适应与多种合适的邻近性度量。
其他问题
离群点的处理
离群点可能过度影响簇的发现,导致簇的最终发布会与我们的预想有较大出入,所以提前发现并剔除离群点是有必要的。
在我的工作中,是利用方差来剔除离群点,结果显示效果非常好。
簇分裂和簇合并
使用较大的K,往往会使得聚类的结果看上去更加合理,但很多情况下,我们并不想增加簇的个数。
这时可以交替采用簇分裂和簇合并。这种方式可以避开局部极小,并且能够得到具有期望个数簇的结果。
抽象了点,簇,和距离
Point.class
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
|
public
class
Point {
private
double
x;
private
double
y;
private
int
id;
private
boolean
beyond;
//标识是否属于样本
public
Point(
int
id,
double
x,
double
y) {
this
.id = id;
this
.x = x;
this
.y = y;
this
.beyond =
true
;
}
public
Point(
int
id,
double
x,
double
y,
boolean
beyond) {
this
.id = id;
this
.x = x;
this
.y = y;
this
.beyond = beyond;
}
public
double
getX() {
return
x;
}
public
double
getY() {
return
y;
}
public
int
getId() {
return
id;
}
public
boolean
isBeyond() {
return
beyond;
}
@Override
public
String toString() {
return
"Point{"
+
"id="
+ id +
", x="
+ x +
", y="
+ y +
'}'
;
}
@Override
public
boolean
equals(Object o) {
if
(
this
== o)
return
true
;
if
(o ==
null
|| getClass() != o.getClass())
return
false
;
Point point = (Point) o;
if
(Double.compare(point.x, x) !=
0
)
return
false
;
if
(Double.compare(point.y, y) !=
0
)
return
false
;
return
true
;
}
@Override
public
int
hashCode() {
int
result;
long
temp;
temp = x != +
0
.0d ? Double.doubleToLongBits(x) : 0L;
result = (
int
) (temp ^ (temp >>>
32
));
temp = y != +
0
.0d ? Double.doubleToLongBits(y) : 0L;
result =
31
* result + (
int
) (temp ^ (temp >>>
32
));
return
result;
}
}
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
|
public
class
Cluster {
private
int
id;
//标识
private
Point center;
//中心
private
List<Point> members =
new
ArrayList<Point>();
//成员
public
Cluster(
int
id, Point center) {
this
.id = id;
this
.center = center;
}
public
Cluster(
int
id, Point center, List<Point> members) {
this
.id = id;
this
.center = center;
this
.members = members;
}
public
void
addPoint(Point newPoint) {
if
(!members.contains(newPoint))
members.add(newPoint);
else
throw
new
IllegalStateException(
"试图处理同一个样本数据!"
);
}
public
int
getId() {
return
id;
}
public
Point getCenter() {
return
center;
}
public
void
setCenter(Point center) {
this
.center = center;
}
public
List<Point> getMembers() {
return
members;
}
@Override
public
String toString() {
return
"Cluster{"
+
"id="
+ id +
", center="
+ center +
", members="
+ members +
"}"
;
}
}
|
1
2
3
|
public
abstract
class
AbstractDistance {
abstract
public
double
getDis(Point p1, Point p2);
}
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
public
class
Distence
implements
Comparable<Distence> {
private
Point source;
private
Point dest;
private
double
dis;
private
AbstractDistance distance;
public
Distence(Point source, Point dest, AbstractDistance distance) {
this
.source = source;
this
.dest = dest;
this
.distance = distance;
dis = distance.getDis(source, dest);
}
public
Point getSource() {
return
source;
}
public
Point getDest() {
return
dest;
}
public
double
getDis() {
return
dis;
}
@Override
public
int
compareTo(Distence o) {
if
(o.getDis() > dis)
return
-
1
;
else
return
1
;
}
}
|
核心实现类
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
|
public
class
KMeansCluster {
private
int
k;
//簇的个数
private
int
num =
100000
;
//迭代次数
private
List<Double> datas;
//原始样本集
private
String address;
//样本集路径
private
List<Point> data =
new
ArrayList<Point>();
private
AbstractDistance distance =
new
AbstractDistance() {
@Override
public
double
getDis(Point p1, Point p2) {
//欧几里德距离
return
Math.sqrt(Math.pow(p1.getX() - p2.getX(),
2
) + Math.pow(p1.getY() - p2.getY(),
2
));
}
};
public
KMeansCluster(
int
k,
int
num, String address) {
this
.k = k;
this
.num = num;
this
.address = address;
}
public
KMeansCluster(
int
k, String address) {
this
.k = k;
this
.address = address;
}
public
KMeansCluster(
int
k, List<Double> datas) {
this
.k = k;
this
.datas = datas;
}
public
KMeansCluster(
int
k,
int
num, List<Double> datas) {
this
.k = k;
this
.num = num;
this
.datas = datas;
}
private
void
check() {
if
(k ==
0
)
throw
new
IllegalArgumentException(
"k must be the number > 0"
);
if
(address ==
null
&& datas ==
null
)
throw
new
IllegalArgumentException(
"program can't get real data"
);
}
/**
* 初始化数据
*
* @throws java.io.FileNotFoundException
*/
public
void
init()
throws
FileNotFoundException {
check();
//读取文件,init data
//处理原始数据
for
(
int
i =
0
, j = datas.size(); i < j; i++)
data.add(
new
Point(i, datas.get(i),
0
));
}
/**
* 第一次随机选取中心点
*
* @return
*/
public
Set<Point> chooseCenter() {
Set<Point> center =
new
HashSet<Point>();
Random ran =
new
Random();
int
roll =
0
;
while
(center.size() < k) {
roll = ran.nextInt(data.size());
center.add(data.get(roll));
}
return
center;
}
/**
* @param center
* @return
*/
public
List<Cluster> prepare(Set<Point> center) {
List<Cluster> cluster =
new
ArrayList<Cluster>();
Iterator<Point> it = center.iterator();
int
id =
0
;
while
(it.hasNext()) {
Point p = it.next();
if
(p.isBeyond()) {
Cluster c =
new
Cluster(id++, p);
c.addPoint(p);
cluster.add(c);
}
else
cluster.add(
new
Cluster(id++, p));
}
return
cluster;
}
/**
* 第一次运算,中心点为样本值
*
* @param center
* @param cluster
* @return
*/
public
List<Cluster> clustering(Set<Point> center, List<Cluster> cluster) {
Point[] p = center.toArray(
new
Point[
0
]);
TreeSet<Distence> distence =
new
TreeSet<Distence>();
//存放距离信息
Point source;
Point dest;
boolean
flag =
false
;
for
(
int
i =
0
, n = data.size(); i < n; i++) {
distence.clear();
for
(
int
j =
0
; j < center.size(); j++) {
if
(center.contains(data.get(i)))
break
;
flag =
true
;
// 计算距离
source = data.get(i);
dest = p[j];
distence.add(
new
Distence(source, dest, distance));
}
if
(flag ==
true
) {
Distence min = distence.first();
for
(
int
m =
0
, k = cluster.size(); m < k; m++) {
if
(cluster.get(m).getCenter().equals(min.getDest()))
cluster.get(m).addPoint(min.getSource());
}
}
flag =
false
;
}
return
cluster;
}
/**
* 迭代运算,中心点为簇内样本均值
*
* @param cluster
* @return
*/
public
List<Cluster> cluster(List<Cluster> cluster) {
// double error;
Set<Point> lastCenter =
new
HashSet<Point>();
for
(
int
m =
0
; m < num; m++) {
// error = 0;
Set<Point> center =
new
HashSet<Point>();
// 重新计算聚类中心
for
(
int
j =
0
; j < k; j++) {
List<Point> ps = cluster.get(j).getMembers();
int
size = ps.size();
if
(size <
3
) {
center.add(cluster.get(j).getCenter());
continue
;
}
// 计算距离
double
x =
0.0
, y =
0.0
;
for
(
int
k1 =
0
; k1 < size; k1++) {
x += ps.get(k1).getX();
y += ps.get(k1).getY();
}
//得到新的中心点
Point nc =
new
Point(-
1
, x / size, y / size,
false
);
center.add(nc);
}
if
(lastCenter.containsAll(center))
//中心点不在变化,退出迭代
break
;
lastCenter = center;
// 迭代运算
cluster = clustering(center, prepare(center));
// for (int nz = 0; nz < k; nz++) {
// error += cluster.get(nz).getError();//计算误差
// }
}
return
cluster;
}
/**
* 输出聚类信息到控制台
*
* @param cs
*/
public
void
out2console(List<Cluster> cs) {
for
(
int
i =
0
; i < cs.size(); i++) {
System.out.println(
"No."
+ (i +
1
) +
" cluster:"
);
Cluster c = cs.get(i);
List<Point> p = c.getMembers();
for
(
int
j =
0
; j < p.size(); j++) {
System.out.println(
"\t"
+ p.get(j).getX() +
" "
);
}
System.out.println();
}
}
}
|
K-Means算法是最古老也是应用最广泛的聚类算法,它使用质心定义原型,质心是一组点的均值,通常该算法用于n维连续空间中的对象。
例如下图的样本集,初始选择是三个质心比较集中,但是迭代3次之后,质心趋于稳定,并将样本集分为3部分
贴上代码java版,以后有时间写个python版的
抽象了点,簇,和距离
Point.class
Cluster.class
抽象的距离,可以具体实现为欧式,曼式或其他距离公式
点对
代码还没有仔细优化,执行的效率可能还存在一定的问题
K-Means算法流程
step1:选择K个点作为初始质心
step2:repeat
将每个点指派到最近的质心,形成K个簇
重新计算每个簇的质心
until 质心不在变化
![](http://static.oschina.net/uploads/img/201306/28152035_sYVx.png)
我们对每一个步骤都进行分析
step1:选择K个点作为初始质心
这一步首先要知道K的值,也就是说K是手动设置的,而不是像EM算法那样自动聚类成n个簇
其次,如何选择初始质心
最简单的方式无异于,随机选取质心了,然后多次运行,取效果最好的那个结果。这个方法,简单但不见得有效,有很大的可能是得到局部最优。
另一种复杂的方式是,随机选取一个质心,然后计算离这个质心最远的样本点,对于每个后继质心都选取已经选取过的质心的最远点。使用这种方式,可以确保质心是随机的,并且是散开的。
step2:repeat
将每个点指派到最近的质心,形成K个簇
重新计算每个簇的质心
until 质心不在变化
如何定义最近的概念,对于欧式空间中的点,可以使用欧式空间,对于文档可以用余弦相似性等等。对于给定的数据,可能适应与多种合适的邻近性度量。
其他问题
离群点的处理
离群点可能过度影响簇的发现,导致簇的最终发布会与我们的预想有较大出入,所以提前发现并剔除离群点是有必要的。
在我的工作中,是利用方差来剔除离群点,结果显示效果非常好。
簇分裂和簇合并
使用较大的K,往往会使得聚类的结果看上去更加合理,但很多情况下,我们并不想增加簇的个数。
这时可以交替采用簇分裂和簇合并。这种方式可以避开局部极小,并且能够得到具有期望个数簇的结果。
抽象了点,簇,和距离
Point.class
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
|
public
class
Point {
private
double
x;
private
double
y;
private
int
id;
private
boolean
beyond;
//标识是否属于样本
public
Point(
int
id,
double
x,
double
y) {
this
.id = id;
this
.x = x;
this
.y = y;
this
.beyond =
true
;
}
public
Point(
int
id,
double
x,
double
y,
boolean
beyond) {
this
.id = id;
this
.x = x;
this
.y = y;
this
.beyond = beyond;
}
public
double
getX() {
return
x;
}
public
double
getY() {
return
y;
}
public
int
getId() {
return
id;
}
public
boolean
isBeyond() {
return
beyond;
}
@Override
public
String toString() {
return
"Point{"
+
"id="
+ id +
", x="
+ x +
", y="
+ y +
'}'
;
}
@Override
public
boolean
equals(Object o) {
if
(
this
== o)
return
true
;
if
(o ==
null
|| getClass() != o.getClass())
return
false
;
Point point = (Point) o;
if
(Double.compare(point.x, x) !=
0
)
return
false
;
if
(Double.compare(point.y, y) !=
0
)
return
false
;
return
true
;
}
@Override
public
int
hashCode() {
int
result;
long
temp;
temp = x != +
0
.0d ? Double.doubleToLongBits(x) : 0L;
result = (
int
) (temp ^ (temp >>>
32
));
temp = y != +
0
.0d ? Double.doubleToLongBits(y) : 0L;
result =
31
* result + (
int
) (temp ^ (temp >>>
32
));
return
result;
}
}
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
|
public
class
Cluster {
private
int
id;
//标识
private
Point center;
//中心
private
List<Point> members =
new
ArrayList<Point>();
//成员
public
Cluster(
int
id, Point center) {
this
.id = id;
this
.center = center;
}
public
Cluster(
int
id, Point center, List<Point> members) {
this
.id = id;
this
.center = center;
this
.members = members;
}
public
void
addPoint(Point newPoint) {
if
(!members.contains(newPoint))
members.add(newPoint);
else
throw
new
IllegalStateException(
"试图处理同一个样本数据!"
);
}
public
int
getId() {
return
id;
}
public
Point getCenter() {
return
center;
}
public
void
setCenter(Point center) {
this
.center = center;
}
public
List<Point> getMembers() {
return
members;
}
@Override
public
String toString() {
return
"Cluster{"
+
"id="
+ id +
", center="
+ center +
", members="
+ members +
"}"
;
}
}
|
1
2
3
|
public
abstract
class
AbstractDistance {
abstract
public
double
getDis(Point p1, Point p2);
}
|
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
|
public
class
Distence
implements
Comparable<Distence> {
private
Point source;
private
Point dest;
private
double
dis;
private
AbstractDistance distance;
public
Distence(Point source, Point dest, AbstractDistance distance) {
this
.source = source;
this
.dest = dest;
this
.distance = distance;
dis = distance.getDis(source, dest);
}
public
Point getSource() {
return
source;
}
public
Point getDest() {
return
dest;
}
public
double
getDis() {
return
dis;
}
@Override
public
int
compareTo(Distence o) {
if
(o.getDis() > dis)
return
-
1
;
else
return
1
;
}
}
|
核心实现类
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
|
public
class
KMeansCluster {
private
int
k;
//簇的个数
private
int
num =
100000
;
//迭代次数
private
List<Double> datas;
//原始样本集
private
String address;
//样本集路径
private
List<Point> data =
new
ArrayList<Point>();
private
AbstractDistance distance =
new
AbstractDistance() {
@Override
public
double
getDis(Point p1, Point p2) {
//欧几里德距离
return
Math.sqrt(Math.pow(p1.getX() - p2.getX(),
2
) + Math.pow(p1.getY() - p2.getY(),
2
));
}
};
public
KMeansCluster(
int
k,
int
num, String address) {
this
.k = k;
this
.num = num;
this
.address = address;
}
public
KMeansCluster(
int
k, String address) {
this
.k = k;
this
.address = address;
}
public
KMeansCluster(
int
k, List<Double> datas) {
this
.k = k;
this
.datas = datas;
}
public
KMeansCluster(
int
k,
int
num, List<Double> datas) {
this
.k = k;
this
.num = num;
this
.datas = datas;
}
private
void
check() {
if
(k ==
0
)
throw
new
IllegalArgumentException(
"k must be the number > 0"
);
if
(address ==
null
&& datas ==
null
)
throw
new
IllegalArgumentException(
"program can't get real data"
);
}
/**
* 初始化数据
*
* @throws java.io.FileNotFoundException
*/
public
void
init()
throws
FileNotFoundException {
check();
//读取文件,init data
//处理原始数据
for
(
int
i =
0
, j = datas.size(); i < j; i++)
data.add(
new
Point(i, datas.get(i),
0
));
}
/**
* 第一次随机选取中心点
*
* @return
*/
public
Set<Point> chooseCenter() {
Set<Point> center =
new
HashSet<Point>();
Random ran =
new
Random();
int
roll =
0
;
while
(center.size() < k) {
roll = ran.nextInt(data.size());
center.add(data.get(roll));
}
return
center;
}
/**
* @param center
* @return
*/
public
List<Cluster> prepare(Set<Point> center) {
List<Cluster> cluster =
new
ArrayList<Cluster>();
Iterator<Point> it = center.iterator();
int
id =
0
;
while
(it.hasNext()) {
Point p = it.next();
if
(p.isBeyond()) {
Cluster c =
new
Cluster(id++, p);
c.addPoint(p);
cluster.add(c);
}
else
cluster.add(
new
Cluster(id++, p));
}
return
cluster;
}
/**
* 第一次运算,中心点为样本值
*
* @param center
* @param cluster
* @return
*/
public
List<Cluster> clustering(Set<Point> center, List<Cluster> cluster) {
Point[] p = center.toArray(
new
Point[
0
]);
TreeSet<Distence> distence =
new
TreeSet<Distence>();
//存放距离信息
Point source;
Point dest;
boolean
flag =
false
;
for
(
int
i =
0
, n = data.size(); i < n; i++) {
distence.clear();
for
(
int
j =
0
; j < center.size(); j++) {
if
(center.contains(data.get(i)))
break
;
flag =
true
;
// 计算距离
source = data.get(i);
dest = p[j];
distence.add(
new
Distence(source, dest, distance));
}
if
(flag ==
true
) {
Distence min = distence.first();
for
(
int
m =
0
, k = cluster.size(); m < k; m++) {
if
(cluster.get(m).getCenter().equals(min.getDest()))
cluster.get(m).addPoint(min.getSource());
}
}
flag =
false
;
}
return
cluster;
}
/**
* 迭代运算,中心点为簇内样本均值
*
* @param cluster
* @return
*/
public
List<Cluster> cluster(List<Cluster> cluster) {
// double error;
Set<Point> lastCenter =
new
HashSet<Point>();
for
(
int
m =
0
; m < num; m++) {
// error = 0;
Set<Point> center =
new
HashSet<Point>();
// 重新计算聚类中心
for
(
int
j =
0
; j < k; j++) {
List<Point> ps = cluster.get(j).getMembers();
int
size = ps.size();
if
(size <
3
) {
center.add(cluster.get(j).getCenter());
continue
;
}
// 计算距离
double
x =
0.0
, y =
0.0
;
for
(
int
k1 =
0
; k1 < size; k1++) {
x += ps.get(k1).getX();
y += ps.get(k1).getY();
}
//得到新的中心点
Point nc =
new
Point(-
1
, x / size, y / size,
false
);
center.add(nc);
}
if
(lastCenter.containsAll(center))
//中心点不在变化,退出迭代
break
;
lastCenter = center;
// 迭代运算
cluster = clustering(center, prepare(center));
// for (int nz = 0; nz < k; nz++) {
// error += cluster.get(nz).getError();//计算误差
// }
}
return
cluster;
}
/**
* 输出聚类信息到控制台
*
* @param cs
*/
public
void
out2console(List<Cluster> cs) {
for
(
int
i =
0
; i < cs.size(); i++) {
System.out.println(
"No."
+ (i +
1
) +
" cluster:"
);
Cluster c = cs.get(i);
List<Point> p = c.getMembers();
for
(
int
j =
0
; j < p.size(); j++) {
System.out.println(
"\t"
+ p.get(j).getX() +
" "
);
}
System.out.println();
}
}
}
|