使用 NumPy 实现。数据量较大时可能需要微调（例如，距离计算会为每个样本与每个簇中心的组合各生成一行中间数据，内存占用随样本数 × 簇数增长）。
# encoding: utf-8
"""
@author: Libing Wang
@time: 2021/3/26 9:51
@file: kmeans.py
@desc: 手动实现 k-means
"""
import random
import numpy as np
from matplotlib import pyplot as plt
def get_cluster(data, center_idx, k):
    """Run one assignment step: label every sample with its nearest center.

    :param data: samples to cluster, a 2-D array (one row per sample)
    :param center_idx: row indices into ``data`` of the k current centers
    :param k: number of clusters (expected to equal ``len(center_idx)``)
    :return: float array of length ``data.shape[0]``; entry i is the cluster
        label (0..k-1) of sample i for this round
    """
    centers = data[center_idx]
    # Only the samples that are not themselves centers need a distance search.
    others = np.delete(range(data.shape[0]), center_idx)
    points = data[others]
    # Broadcast (m, 1, d) - (1, k, d) -> (m, k, d), then reduce the feature
    # axis to get the Euclidean distance of every point to every center.
    diff = points[:, np.newaxis, :] - centers[np.newaxis, :, :]
    dist = np.sqrt(np.sum(diff ** 2, axis=-1))
    labels = np.zeros(shape=(data.shape[0],))
    # Each center is pinned to its own cluster, in the order of center_idx.
    labels[center_idx] = np.arange(k)
    labels[others] = np.argmin(dist, axis=-1)
    return labels
def k_means(data, k, max_iter=300):
    """Cluster ``data`` into ``k`` groups with a medoid-style k-means.

    Centers are always actual samples. After each assignment step, the new
    center of a cluster is the member sample closest to that cluster's mean.
    This update rule does not provably decrease a fixed objective, so the
    center choice is not guaranteed to settle. ``max_iter`` bounds the loop
    so the function always returns.

    :param data: samples to cluster, a 2-D array (one row per sample)
    :param k: number of clusters; must not exceed the number of samples
        (``random.sample`` raises ValueError otherwise)
    :param max_iter: maximum number of assign/update rounds (>= 1)
    :return: per-sample cluster labels as produced by ``get_cluster``
        (float array of length ``data.shape[0]``). If ``max_iter`` is reached
        before the centers stop changing, the labels from the last
        assignment step are returned.
    :raises ValueError: if ``max_iter`` is less than 1
    """
    if max_iter < 1:
        raise ValueError("max_iter must be >= 1")
    # Randomly pick k distinct samples as the initial centers.
    center_idx = random.sample(range(data.shape[0]), k)
    for _ in range(max_iter):
        cluster = get_cluster(data, center_idx, k)
        # Update each center to the member sample nearest the cluster mean.
        # Every cluster contains at least its own center, so it is never empty.
        update_center_idx = []
        for i in range(k):
            (members,) = np.where(cluster == i)
            c = np.mean(data[members], axis=0)
            distance = np.sqrt(np.sum((data[members] - c) ** 2, axis=1))
            update_center_idx.append(int(members[np.argmin(distance)]))
        if center_idx == update_center_idx:
            break
        center_idx = update_center_idx
    return cluster
if __name__ == '__main__':
    # Demo: cluster 500 random 2-D points and plot each cluster with its own marker.
    k = 3
    data = np.random.normal(3, 10, (500, 2))
    cluster = k_means(data, k)
    print(cluster)
    # Plot one scatter series per label so the plot follows k automatically;
    # markers repeat if k exceeds the list length.
    markers = ['*', '^', 'o']
    for label in range(k):
        members = cluster == label
        plt.scatter(data[members, 0], data[members, 1],
                    marker=markers[label % len(markers)])
    plt.show()
聚类结果：