基于python实现k-means聚类算法
说明:本文仅用代码实现自己对 K-means 聚类算法的理解,不涉及该算法的原理及优化方法;随后利用生成的测试数据对聚类算法进行验证。后续将研究 K-means 的优化方法及低复杂度的实现方法。代码如下:
import numpy as np
import matplotlib.pyplot as plt
# from sklearn.cluster import KMeans
# Distance computation
def eucdistance(v1, v2):
    """Return the row-wise Euclidean distance between ``v1`` and ``v2``.

    Each argument may be an ndarray or any array-like (list, tuple, scalar).
    Inputs with fewer than two dimensions are promoted to a single row of
    shape ``(1, d)``; the two operands are then subtracted with normal NumPy
    broadcasting, so one point can be compared against many rows at once.

    Args:
        v1: A point of shape ``(d,)`` or a batch of points ``(n, d)``.
        v2: A point of shape ``(d,)`` or a batch of points ``(n, d)``.

    Returns:
        An array with one distance per row of the broadcast difference.
    """
    # atleast_2d accepts array-likes and turns a 1-D vector into one row,
    # which is what the squared differences are summed over (axis 1).
    v1 = np.atleast_2d(v1)
    v2 = np.atleast_2d(v2)
    return np.sqrt(np.sum(np.power(v1 - v2, 2), axis=1))
def init_k_centers(x, k):
    """Return the first ``k`` rows of ``x`` as initial cluster centres.

    The rows are taken in their current order; no random selection happens
    here. Shuffle ``x`` beforehand if a random initialisation is wanted.

    Args:
        x: 2-D array of samples, one sample per row.
        k: Number of centres to return; must satisfy ``1 <= k <= len(x)``.

    Returns:
        A new ``(k, d)`` array. It is a copy, so updating the centres in
        place never writes back into ``x``.

    Raises:
        AssertionError: If ``x`` is not 2-D.
        ValueError: If ``k`` is outside ``1..len(x)``.
    """
    assert x.ndim == 2, "x's dim not is 2"
    if not 1 <= k <= x.shape[0]:
        raise ValueError(
            "k must be between 1 and the number of samples (%d), got %r"
            % (x.shape[0], k)
        )
    # A basic slice is only a view of x; copy it so callers can overwrite
    # the centres without corrupting the first k samples.
    return x[:k].copy()
# np.hstack((sample, sample))
def k_means_v1(x_data, k, ITERS, error_thred=1e-3, random_state=None, showfig=True, colors=("g", "b", "y", "k")):
    """Cluster the rows of ``x_data`` into ``k`` groups with k-means.

    The samples are shuffled (on a private copy), the first ``k`` shuffled
    rows become the initial centres, and then assignment / centre-update
    steps are repeated until the mean centre movement drops to
    ``error_thred`` or ``ITERS`` iterations have run.

    Args:
        x_data: 2-D array-like of samples, one sample per row. It is not
            modified.
        k: Number of clusters; must satisfy ``1 <= k <= number of samples``.
        ITERS: Maximum number of iterations.
        error_thred: Stop once the mean Euclidean distance between the old
            and new centres is less than or equal to this value.
        random_state: Seed for the shuffle; ``None`` gives a fresh random
            order on every call.
        showfig: If true, draw the first two features of every cluster with
            matplotlib after each iteration and keep the final figure open.
        colors: Matplotlib colours, one per cluster; reused cyclically when
            ``k`` exceeds the number of colours.

    Returns:
        A dict mapping each cluster index ``0..k-1`` to a float array of the
        samples assigned to it (possibly empty). An empty dict is returned
        when ``ITERS`` is less than 1.

    Raises:
        AssertionError: If ``x_data`` is not 2-D.
        ValueError: If ``k`` is outside ``1..number of samples``.
    """
    # Private float copy: the caller's array is neither shuffled nor
    # overwritten, and integer input cannot truncate the centre means.
    data = np.array(x_data, dtype=float)
    assert data.ndim == 2, "x's dim not is 2"
    if not 1 <= k <= data.shape[0]:
        raise ValueError(
            "k must be between 1 and the number of samples (%d), got %r"
            % (data.shape[0], k)
        )

    # A local generator yields the same permutation as seeding the global
    # one with random_state, without touching global RNG state.
    np.random.RandomState(random_state).shuffle(data)

    # Copy, not a slice view: updating a centre must not change the data.
    centers = data[:k].copy()

    results = {}
    for step in range(ITERS):
        # (n, k) distance matrix; argmin picks the first nearest centre,
        # matching a strict "<" scan over the centres.
        diffs = data[:, np.newaxis, :] - centers[np.newaxis, :, :]
        labels = np.sqrt(np.sum(diffs ** 2, axis=2)).argmin(axis=1)

        previous_centers = centers.copy()

        # Update the cluster centres.
        results = {}
        for j_c in range(k):
            members = data[labels == j_c]
            results[j_c] = members
            # An empty cluster keeps its old centre instead of becoming NaN.
            if members.shape[0] > 0:
                centers[j_c] = members.mean(axis=0)

        # Mean distance between the new centres and the previous ones.
        dist_center = np.sqrt(
            np.sum((centers - previous_centers) ** 2, axis=1)
        ).mean()
        finished = dist_center <= error_thred or step == ITERS - 1

        if showfig:
            plt.ion()
            plt.show()
            plt.cla()
            for key, members in results.items():
                color = colors[key % len(colors)]
                plt.scatter(members[:, 0], members[:, 1], color=color, marker=".", s=8)
                plt.scatter(centers[key][0], centers[key][1], color=color, marker="*", s=15)
            plt.text(1.5, 0.5, "iter=%.2i, Loss=%.4f" % (step, dist_center))
            plt.pause(0.1)
            if finished:
                # Leave interactive mode so the final figure stays open.
                plt.ioff()
                plt.show()

        if finished:
            print("*." * 20 + " k-means done " + "*." * 20)
            return results
    return results
if __name__ == '__main__':
    # Demo data: a small hand-written set of 2-D points, clustered with k=2
    # and plotted after every iteration.
    x_data = np.array([
        [3, 10],
        [2, 9],
        [1, 9],
        [3, 7],
        [4, 8],
        [3.5, 6],
        [9, 0.5],
        [8, 3],
        [9, 2],
        [8, 1],
        [10, 0.5],
        [6.5, 2],
        [6.5, 4],
    ])
    results = k_means_v1(
        x_data=x_data,
        k=2,
        ITERS=20,
        error_thred=1e-3,
        random_state=42,
        showfig=True,
        colors=["g", "b", "y", "k"],
    )