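# 该代码由知乎作者iSun完成，实现了高效的numpy运算。
# This code was written by Zhihu author iSun; it implements k-means using efficient, vectorized NumPy operations.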
import numpy as np
import matplotlib.pyplot as plt
def distance(vector1, vector2):
    """Return the Euclidean (L2) distance between two arrays.

    The squared element-wise differences are summed over *all* elements, so
    for 2-D inputs this is the Frobenius norm of ``vector2 - vector1``.
    ``kmeans`` uses it that way to measure how far the whole set of centers
    moved between iterations.
    """
    delta = vector2 - vector1
    squared_total = np.square(delta).sum()
    return np.sqrt(squared_total)
def init_centroids(X, k, seed=110):
    """Pick ``k`` distinct rows of ``X`` at random as the initial centroids.

    Parameters
    ----------
    X : array-like
        Samples, one per row.
    k : int
        Number of centroids to draw (without replacement).
    seed : int or None
        Seed for the random choice. Calls with the same seed pick the same rows.

    Returns
    -------
    ndarray
        The selected rows of ``X``, in draw order.

    Raises
    ------
    ValueError
        If ``k`` exceeds the number of rows in ``X``.
    """
    X = np.asarray(X)
    n = X.shape[0]
    # Explicit check instead of ``assert``: asserts vanish under ``python -O``.
    if k > n:
        raise ValueError(f"k={k} must not exceed the number of samples n={n}")
    # A private RandomState seeded with ``seed`` produces exactly the same
    # draws as ``np.random.seed(seed)`` followed by ``np.random.choice``, but
    # does not clobber the caller's global NumPy RNG state.
    rng = np.random.RandomState(seed)
    idxs = rng.choice(n, k, replace=False)
    return X[idxs]
def compute_centroids(X, clusters, k, fallback=None):
    """Compute the mean of each of the ``k`` clusters.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
        Samples, one per row.
    clusters : ndarray of int, shape (n_samples,)
        Cluster index assigned to each row of ``X``.
    k : int
        Number of clusters; indices ``0 .. k-1`` are computed.
    fallback : array-like of shape (k, n_features), optional
        Row ``i`` is used as centroid ``i`` when cluster ``i`` has no members.

    Returns
    -------
    ndarray of float, shape (k, n_features)
        Row ``i`` is the mean of the samples assigned to cluster ``i``. An
        empty cluster yields ``fallback[i]`` if given, otherwise a row of NaN.
        This is the same value as before, but it no longer goes through
        ``np.mean`` of an empty slice and its RuntimeWarning.
    """
    new_centroids = np.full((k, X.shape[1]), np.nan)
    for i in range(k):
        members = X[clusters == i]
        if len(members):
            new_centroids[i] = members.mean(axis=0)
        elif fallback is not None:
            new_centroids[i] = fallback[i]
    return new_centroids
def compute_distances(X, centers):
    """Return the Euclidean distance from every sample to every center.

    Uses the expansion ``||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2`` so the
    whole matrix comes from a single matrix product, with no Python loop.

    Parameters
    ----------
    X : ndarray, shape (n_samples, n_features)
    centers : ndarray, shape (n_centers, n_features)

    Returns
    -------
    ndarray, shape (n_samples, n_centers)
        Entry ``[i, j]`` is the distance between ``X[i]`` and ``centers[j]``.
    """
    double_xy = 2 * X.dot(centers.T)
    sq_X = np.sum(np.square(X), axis=1, keepdims=True)  # (n_samples, 1)
    sq_centers = np.sum(np.square(centers), axis=1)  # broadcasts over rows
    sq_dists = sq_X - double_xy + sq_centers
    # Cancellation in the expansion can push the squared distance of
    # (near-)coincident points slightly below zero. Without the clamp, sqrt
    # would return NaN, and np.argmin treats NaN as the minimum.
    return np.sqrt(np.maximum(sq_dists, 0))
def kmeans(X, k, seed=110, tolerance=1e-5, max_iter=1000, verbose=False):
    """Cluster the rows of ``X`` into ``k`` groups with Lloyd's k-means.

    The algorithm alternates between assigning every sample to its nearest
    center and moving every center to the mean of its assigned samples.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
    k : int
        Number of clusters. Any error ``init_centroids`` raises when ``k``
        exceeds the number of samples is propagated.
    seed : int
        Seed forwarded to ``init_centroids`` to choose the starting centers.
    tolerance : float
        Stop once the total shift of the centers (``distance`` over the whole
        center matrix) is at most this value.
    max_iter : int
        Maximum number of reassignment steps.
    verbose : bool
        If True, print the number of iterations performed and the final
        center shift. Earlier versions always printed these values.

    Returns
    -------
    clusters : ndarray of int, shape (n_samples,)
        Index of the nearest center for each sample.
    centers : ndarray, shape (k, n_features)
        The centers that ``clusters`` was assigned against.
    """
    X = np.asarray(X)
    centers = init_centroids(X, k, seed)
    clusters = np.argmin(compute_distances(X, centers), axis=1)
    n_iter = 0
    while True:
        new_centers = compute_centroids(X, clusters, k)
        # An empty cluster produces a NaN centroid row. Keep its previous
        # center so the NaN does not poison the convergence test or the
        # next distance matrix.
        empty = np.isnan(new_centers).any(axis=1)
        new_centers[empty] = centers[empty]
        shift = distance(new_centers, centers)
        # ``>=`` caps the loop at exactly ``max_iter`` reassignment steps.
        if n_iter >= max_iter or shift <= tolerance:
            if verbose:
                print(n_iter)
                print(shift)
            break
        centers = new_centers
        clusters = np.argmin(compute_distances(X, centers), axis=1)
        n_iter += 1
    return clusters, centers
if __name__ == '__main__':
    # Demo: two 2-D Gaussian blobs of 30 points each (means 2 and 5, std 1).
    sample1 = np.random.normal(2, 1, 60).reshape((30, 2))
    sample2 = np.random.normal(5, 1, 60).reshape((30, 2))
    X = np.concatenate([sample1, sample2])
    clusters, centers = kmeans(X, 2)
    plt.scatter(sample1[:, 0], sample1[:, 1])
    plt.scatter(sample2[:, 0], sample2[:, 1])
    # Fitted centers drawn as red crosses on top of the raw samples.
    plt.scatter(centers[:, 0], centers[:, 1], marker='x', color='r', s=50)
    # Without show(), a non-interactive run exits before the figure appears.
    plt.show()