k-means算法究极无敌简洁实现(PyTorch)
def kmeans(x, ncluster, niter=10):
    """Cluster the rows of ``x`` with k-means and return the cluster centers.

    Each iteration assigns every row of ``x`` to its nearest center (squared
    Euclidean distance), moves every center to the mean of the rows assigned
    to it, and re-initializes any center that received no rows from randomly
    chosen rows of ``x``. A progress line is printed after every iteration.

    Args:
        x: tensor of size (data_num, data_dim), one data point per row.
        ncluster: number of clusters; must satisfy 1 <= ncluster <= data_num.
        niter: number of assignment/update iterations to run.

    Returns:
        Tensor of size (ncluster, data_dim) holding the cluster centers.

    Raises:
        ValueError: if ``x`` is not 2-D, or if ``ncluster`` is outside
            ``[1, data_num]``.
    """
    # Unpacking two values also rejects inputs that are not 2-D.
    N, _ = x.size()
    if not 1 <= ncluster <= N:
        # Without this guard, slicing the permutation would silently yield
        # fewer (or zero) initial centers and fail later with an opaque error.
        raise ValueError(
            'ncluster must be between 1 and the number of data points '
            '(%d), got %d' % (N, ncluster)
        )

    # Initialize the centers with ncluster distinct, randomly chosen rows.
    c = x[torch.randperm(N)[:ncluster]]
    for i in range(niter):
        # Assign every data point to its closest center: broadcast to an
        # (N, ncluster, data_dim) difference tensor, reduce to squared
        # distances of size (N, ncluster), then take the index of the
        # minimum along dim 1.
        a = ((x[:, None, :] - c[None, :, :]) ** 2).sum(-1).argmin(1)
        # Move each center to the mean of the points assigned to it.
        c = torch.stack([x[a == k].mean(0) for k in range(ncluster)])
        # A cluster with no assigned points has a NaN mean; those clusters
        # are "dead" and get re-initialized from random data points.
        nanix = torch.any(torch.isnan(c), dim=1)
        ndead = nanix.sum().item()
        print('done step %d/%d, re-initialized %d dead clusters' % (i+1, niter, ndead))
        c[nanix] = x[torch.randperm(N)[:ndead]]
    return c
摘录自minGPT实现代码。