code
import numpy as np
import matplotlib.pyplot as plt
import random
def kMeans(daSet, K, maxIter=300, tol=1e-6):
    """Cluster daSet into K groups with Lloyd's k-means algorithm.

    Alternates the assignment step (each point -> nearest center) and the
    update step (each center -> mean of its points) until the mean center
    movement drops to ``tol`` or ``maxIter`` rounds have been run.

    Args:
        daSet: sequence of points (list of coordinate lists or a 2-D array).
        K: number of clusters; must not exceed len(daSet).
        maxIter: safety cap on the number of iterations (must be >= 1).
        tol: convergence threshold on the mean center movement.

    Returns:
        (centers, ind, iterNum): the final (K, d) float array of centers,
        the cluster label of every point, and the number of rounds run.

    Raises:
        ValueError: if maxIter < 1 or K > len(daSet).
    """
    if maxIter < 1:
        raise ValueError("maxIter must be >= 1")
    data = np.asarray(daSet, dtype=float)
    # Sample indices instead of rows so numpy arrays are accepted as well;
    # the picked indices are exactly those random.sample(daSet, K) would pick.
    centerPoint = data[random.sample(range(len(data)), K)]
    iterNum = 0
    while True:
        ind = assignment(daSet, centerPoint)
        # reshape guards against update() squeezing a (1, d) result to (d,)
        newCenterPoint = np.asarray(update(daSet, ind, K), dtype=float).reshape(K, -1)
        # An empty cluster yields a NaN mean; since NaN > tol is False the loop
        # would otherwise stop early with NaN centers. Re-seed such clusters.
        for i in np.flatnonzero(np.isnan(newCenterPoint).any(axis=1)):
            newCenterPoint[i] = data[random.randrange(len(data))]
        center_dis = [calDis(centerPoint[i], newCenterPoint[i]) for i in range(K)]
        centerPoint = newCenterPoint  # new iteration
        iterNum += 1
        print("cent Dis: ", center_dis, np.mean(center_dis))
        if np.mean(center_dis) <= tol or iterNum >= maxIter:
            break
    return centerPoint, ind, iterNum
def assignment(daSet, centerPoint):
    """Assign every point to its nearest center.

    Builds the full point-to-center distance table with calDis, prints it
    together with the resulting labels, and returns the labels.

    Returns:
        numpy integer array; entry j is the index of the center closest to
        daSet[j] (the first one on ties).
    """
    dis_all = []
    for point in daSet:
        row = [calDis(point, center) for center in centerPoint]
        dis_all.append(row)
    # argmin (not min): we need the position of the smallest distance,
    # which is the cluster the point belongs to.
    labels = np.argmin(np.array(dis_all), axis=1)
    print(dis_all, "\n ind= ", labels)
    return labels
def update(daSet, ind, K, verbose=True):
    """Recompute each cluster center as the mean of its member points.

    Args:
        daSet: sequence of points (list of coordinate lists or 2-D array).
        ind: cluster label of every point, as returned by ``assignment``.
        K: number of clusters.
        verbose: print the last cluster's members and the new centers
            (the original debug output); pass False to silence it.

    Returns:
        (K, d) float array of new centers. A cluster with no members is
        re-seeded at a randomly chosen data point instead of becoming NaN.

    Raises:
        ValueError: if daSet is empty.
    """
    data = np.asarray(daSet, dtype=float)
    if data.ndim == 1:  # scalar points: treat each as a 1-D coordinate
        data = data.reshape(-1, 1)
    if len(data) == 0:
        raise ValueError("daSet must contain at least one point")
    labels = np.asarray(ind)
    centers = np.empty((K, data.shape[1]))
    members = data[:0]
    for i in range(K):
        members = data[labels == i]
        if len(members) == 0:
            # np.mean of an empty slice is NaN (plus a RuntimeWarning); pick a
            # random point so the cluster can be repopulated next round.
            centers[i] = data[random.randrange(len(data))]
        else:
            centers[i] = members.mean(axis=0)
    if verbose:
        print("clu", members, centers)
    return centers
def calDis(p1, p2):
    """Return the Euclidean distance between points p1 and p2.

    Both arguments may be any array-like of equal length (lists, tuples,
    numpy arrays).
    """
    diff = np.array(p1) - np.array(p2)
    # square root of the summed squared coordinate differences
    return np.sqrt(np.sum(np.power(diff, 2)))
def createDataSet():
    """Return the small hand-made 2-D demo data set as a new list of [x, y] lists."""
    coords = (
        (1, 1), (2, 1), (3, 1), (2, 3),
        (10, 9), (9, 9), (7, 8), (6, 5), (10, 11),
        (3, 3), (13, 12), (13, 10),
        (-5, -2), (-10, -9), (-8, -7),
    )
    return [list(xy) for xy in coords]
if __name__ == "__main__":
    daSet = createDataSet()
    newCenterPoint, ind, iterNum = kMeans(daSet, 3)
    # matplotlib format strings, one per cluster label
    cl = ['bs','k<','c*','yo']
    print("Final", newCenterPoint, ind, "\n iterative rounds:\n",iterNum)
    plt.figure()
    points = np.asarray(daSet)
    labels = np.asarray(ind)
    # ind[i] is the cluster of point i: draw each cluster with one plot call
    # (instead of one per point); styles cycle if K exceeds len(cl).
    for k in np.unique(labels):
        members = points[labels == k]
        plt.plot(members[:, 0], members[:, 1], cl[k % len(cl)])
    # Draw the centers last so the data markers cannot hide them.
    centers = np.asarray(newCenterPoint)
    plt.plot(centers[:, 0], centers[:, 1], 'xr')
    plt.show()
- 主要函数有:创建数据(也可以用正态分布随机生成,效果更好)、控制整体流程的 k-means,以及两个核心迭代步骤——分配 assignment 和更新 update:前者按距离最小原则将各点归类,后者计算新的类中心;当类中心不再变化时判定为收敛;最后在 main 函数中输出图形。
results
来源
https://en.wikipedia.org/wiki/K-means_clustering
问题
- 数据存储结构应设计为 list 还是 array,值得斟酌:array 便于数值计算,list 更便于创建和追加更新。
- 如何根据索引分类:先用 ind = np.argmin(dis, axis=1) 得到每个点所属的类别,再用 cl_ind = np.argwhere(ind == i) 找出第 i 类各点的序号,最后用 np.array(daSet)[cl_ind] 输出每一类。
改进版(用正态分布生成数据)
import numpy as np
import matplotlib.pyplot as plt
import random
# matplotlib format strings (colour + marker), one per cluster label
cl = "bs k< c* yo".split()
def kMeans(daSet, K, maxIter=300, tol=1e-6):
    """Cluster daSet into K groups with Lloyd's k-means algorithm.

    Alternates the assignment step (each point -> nearest center) and the
    update step (each center -> mean of its points) until the mean center
    movement drops to ``tol`` or ``maxIter`` rounds have been run.

    Args:
        daSet: sequence of points (list of coordinate lists or a 2-D array).
        K: number of clusters; must not exceed len(daSet).
        maxIter: safety cap on the number of iterations (must be >= 1).
        tol: convergence threshold on the mean center movement.

    Returns:
        (centers, ind, iterNum): the final (K, d) float array of centers,
        the cluster label of every point, and the number of rounds run.

    Raises:
        ValueError: if maxIter < 1 or K > len(daSet).
    """
    if maxIter < 1:
        raise ValueError("maxIter must be >= 1")
    data = np.asarray(daSet, dtype=float)
    # Sample indices instead of rows so numpy arrays are accepted as well;
    # the picked indices are exactly those random.sample(daSet, K) would pick.
    centerPoint = data[random.sample(range(len(data)), K)]
    iterNum = 0
    while True:
        ind = assignment(daSet, centerPoint)
        # reshape guards against update() squeezing a (1, d) result to (d,)
        newCenterPoint = np.asarray(update(daSet, ind, K), dtype=float).reshape(K, -1)
        # An empty cluster yields a NaN mean; since NaN > tol is False the loop
        # would otherwise stop early with NaN centers. Re-seed such clusters.
        for i in np.flatnonzero(np.isnan(newCenterPoint).any(axis=1)):
            newCenterPoint[i] = data[random.randrange(len(data))]
        center_dis = [calDis(centerPoint[i], newCenterPoint[i]) for i in range(K)]
        centerPoint = newCenterPoint  # new iteration
        iterNum += 1
        print("cent Dis: ", center_dis, np.mean(center_dis))
        if np.mean(center_dis) <= tol or iterNum >= maxIter:
            break
    return centerPoint, ind, iterNum
def assignment(daSet, centerPoint):
    """Assign every point to its nearest center.

    Builds the full point-to-center distance table with calDis, prints it
    together with the resulting labels, and returns the labels.

    Returns:
        numpy integer array; entry j is the index of the center closest to
        daSet[j] (the first one on ties).
    """
    dis_all = []
    for point in daSet:
        row = [calDis(point, center) for center in centerPoint]
        dis_all.append(row)
    # argmin (not min): we need the position of the smallest distance,
    # which is the cluster the point belongs to.
    labels = np.argmin(np.array(dis_all), axis=1)
    print(dis_all, "\n ind= ", labels)
    return labels
def update(daSet, ind, K, verbose=True):
    """Recompute each cluster center as the mean of its member points.

    Args:
        daSet: sequence of points (list of coordinate lists or 2-D array).
        ind: cluster label of every point, as returned by ``assignment``.
        K: number of clusters.
        verbose: print the last cluster's members and the new centers
            (the original debug output); pass False to silence it.

    Returns:
        (K, d) float array of new centers. A cluster with no members is
        re-seeded at a randomly chosen data point instead of becoming NaN.

    Raises:
        ValueError: if daSet is empty.
    """
    data = np.asarray(daSet, dtype=float)
    if data.ndim == 1:  # scalar points: treat each as a 1-D coordinate
        data = data.reshape(-1, 1)
    if len(data) == 0:
        raise ValueError("daSet must contain at least one point")
    labels = np.asarray(ind)
    centers = np.empty((K, data.shape[1]))
    members = data[:0]
    for i in range(K):
        members = data[labels == i]
        if len(members) == 0:
            # np.mean of an empty slice is NaN (plus a RuntimeWarning); pick a
            # random point so the cluster can be repopulated next round.
            centers[i] = data[random.randrange(len(data))]
        else:
            centers[i] = members.mean(axis=0)
    if verbose:
        print("clu", members, centers)
    return centers
def calDis(p1, p2):
    """Return the Euclidean distance between points p1 and p2.

    Both arguments may be any array-like of equal length (lists, tuples,
    numpy arrays).
    """
    diff = np.array(p1) - np.array(p2)
    # square root of the summed squared coordinate differences
    return np.sqrt(np.sum(np.power(diff, 2)))
def createDataSet(flag="simple", num=100):
    """Return a demo data set as a list of [x, y] lists.

    Args:
        flag: "norm" draws ``num`` points from each of four 2-D Gaussian
            blobs, plots them (one style per blob) in a new figure as the
            ground truth, and returns all points in shuffled order. Any
            other value returns the small hand-made 15-point set.
        num: number of points per Gaussian blob (only used for "norm").
    """
    if flag == "norm":
        # (scale, offset) per blob; the offset is subtracted, so the blob
        # means are (-10, 0), (0, 0), (10, 0) and (5, -5).
        specs = [(5, [10, 0]), (10, [0, 0]), (5, [-10, 0]), (20, [-5, 5])]
        groups = [scale * np.random.randn(num, 2) - np.array(offset)
                  for scale, offset in specs]
        plt.figure()
        # Plot each blob from its own array rather than slicing a flat list
        # with a hard-coded 100, which broke whenever num changed.
        for i, group in enumerate(groups):
            plt.plot(group[:, 0], group[:, 1], cl[i % len(cl)])
        points = [p for group in groups for p in group.tolist()]
        return random.sample(points, len(points))  # shuffled copy
    return [[1,1], [2,1], [3,1], [2,3], [10,9], [9,9],[7,8],[6,5],[10,11], [3,3],[13,12], [13,10],[-5,-2],[-10,-9],[-8,-7]]
if __name__ == "__main__":
    daSet = createDataSet("norm")
    newCenterPoint, ind, iterNum = kMeans(daSet, 4)
    print("Final: *****", newCenterPoint, ind, "\n\n iterative rounds:",iterNum)
    plt.figure()
    points = np.asarray(daSet)
    labels = np.asarray(ind)
    # ind[i] is the cluster of point i: draw each cluster with one plot call
    # instead of one call per point; styles cycle if K exceeds len(cl).
    for k in np.unique(labels):
        members = points[labels == k]
        plt.plot(members[:, 0], members[:, 1], cl[k % len(cl)])
    # label the central points (drawn last so they stay visible)
    centers = np.asarray(newCenterPoint)
    plt.plot(centers[:, 0], centers[:, 1], 'xr')
    plt.show()