K-means算法
最近用到了k-means算法,这里简单实现一下,和大家交流进步,介绍就不写了,需要的朋友看这里k-means维基百科
实现
这里是一个比较简单的实现,参考下图的伪代码。
有几点细节列出来
- 迭代过程中最初的中心点是随机生成的,所以即使相同的样本也不能保证每次的迭代结果都是一样的
- K的选择
- 结合具体环境,样本分布等先验,知道是几个类别
- 尝试不同的K值并比较确定最终的K值
- 使用其他的聚类方法,如EM
- 聚类质量的评价
使用类别最小外接圆形的直径同类间距离的比值 - K-means的局限性
- 当类别分布不是圆形或者球形时,聚类效果相对较差
- 对外点敏感
# k-means clustering
import random
import time
import numpy as np
import matplotlib.pyplot as plt
def distance_l2(centroids, point):
"""
计算当前样本同各个类别中心点的距离
:param centroids: 样本中心点
:param point: 当前样本
:return: 当前样本所属类别
"""
dist_dist = [None]*3
for i, centroid in enumerate(centroids):
dist = np.sqrt((centroid[0]-point[0])**2 + (centroid[1]-point[1])**2)
dist_dist[i] = [dist, i]
dist_dist.sort(key=lambda elem: elem[0])
return dist_dist[0][1]
if __name__ == "__main__":
random.seed()
# 生成数据
class_n = 3 # 类别
class_i_size = 50 # 每个类别的样本个数
centers = [[40, 60], [100, 80], [75, 50]] # 生成数据时的样本中心点
samples = [[]]*class_n
icon = ['b.', 'r+', 'gx']
fig = plt.figure()
points_prop = {}
for i in range(class_n):
x, y = centers[i]
for j in range(class_i_size):
generateddata = (x + random.normalvariate(0, 8), y + random.normalvariate(0, 8))
if not generateddata in samples[i]: # 去掉重复数据
samples[i].append(generateddata)
plt.plot(generateddata[0], generateddata[1], icon[i])
plt.show()
max_iters = 10 # 最大迭代次数
samples_shuttle = samples[0] + samples[1] + samples[2] # 样本
size = len(samples_shuttle) # 样本大小
# 聚类
# 随机选择中心点
centroids = random.sample(samples_shuttle, class_n)
for i in range(max_iters):
print('iter:\t', i)
centroid_samples = [[], [], []]
for k in range(class_n):
plt.scatter(centroids[k][0], centroids[k][1], c=icon[k][0], s=100) # 绘制当前中心点
# 计算距离并归类
for j in range(size):
index = distance_l2(centroids, samples_shuttle[j])
centroid_samples[index].append(samples_shuttle[j])
plt.plot(samples_shuttle[j][0], samples_shuttle[j][1], icon[index])
# 如果其中一类的样本个数太少,则重新随机生成中心点
if True in (len(centroid_samples[0]) < 0.1*size,
len(centroid_samples[1]) < 0.1*size,
len(centroid_samples[2]) < 0.1*size):
centroids = random.sample(samples_shuttle, class_n)
i -= 1
print('continue')
continue
# 根据重新聚类结果计算中心点
static_flag = 0
for m in range(class_n):
aver_x, aver_y = 0., 0.
size_m = len(centroid_samples[m])
for p in centroid_samples[m]:
aver_x += p[0]
aver_y += p[1]
centroid_temp = [aver_x/size_m, aver_y/size_m]
if centroids[m] == centroid_temp:
static_flag += 1
centroids[m] = centroid_temp
time.sleep(1)
plt.show()
# 当聚类中心点不再变化时,聚类收敛,直接退出
if static_flag == class_n:
print("converge")
break
if i == max_iters-1:
# 一直迭代至最大次数退出
print('exit at max_iters')
print('end')
结果
这是代码的一次结果,较大的红绿蓝圆点是当前分布的中心点,第一张图是生成数据是各类别的分布情况,最后一张是聚类的结果,中间是聚类过程。