# 1. 导入库（Import libraries）
import numpy as np
import matplotlib.pyplot as plt
import random as rand
from mpl_toolkits.mplot3d import Axes3D
import random
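# 2. 定义函数（Define the functions）
# - 定义距离（distance metric）
# - 定义初始聚类中心（choose the initial cluster centres）
# - 为每个点附上标签（assign each point to its nearest centre）
# - 重新计算聚类中心（recompute each centre as the mean of its cluster）
# - 更新标签（update the labels）
# - 检查聚类中心的移动是否小于阈值（check whether the centres moved less than the tolerance）
# - 不断迭代直至收敛（iterate until convergence）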
def Euclidean_Distance(xi, xj):
    """Return the Euclidean (L2) distance between the points *xi* and *xj*."""
    offset = xi - xj
    return np.sqrt(np.square(offset).sum())
def init(K, X):  # randomly choose the initial prototype vectors
    """Pick K rows of X (at distinct indices) as the initial cluster centres.

    Parameters
    ----------
    K : int
        Number of clusters.
    X : array-like of shape (n, d)
        Data points, one per row.

    Returns
    -------
    p : numpy.ndarray of shape (K, d)
        Initial centres (copies of K randomly chosen rows of X).
    t : numpy.ndarray of shape (K,)
        Cluster labels 0 .. K-1; t[i] labels p[i].
    p1 : numpy.ndarray of shape (K, d)
        Stand-in "previous" centres, offset by 1 in every coordinate from p,
        so the first convergence check in K_means cannot pass.

    Raises
    ------
    ValueError
        If K is larger than the number of rows in X (from random.sample).
    """
    X = np.asarray(X)
    t = np.arange(0, K, 1)
    # Sample from the X argument (not a module-level global). Sampling indices
    # consumes the random generator exactly as sampling the rows would.
    idx = random.sample(range(len(X)), K)
    p = X[idx]
    # Guaranteed to differ from p (each row is ~sqrt(d) away), unlike a
    # constant sentinel that could coincide with sampled centres.
    p1 = p + 1
    return p, t, p1
def mylabel(p, X, t):  # assign distance-based labels
    """Label every point with the label of its nearest cluster centre.

    Parameters
    ----------
    p : array-like of shape (K, d)
        Cluster centres, one per row.
    X : array-like of shape (n, d)
        Data points, one per row.
    t : array-like of length K
        Label of each centre; t[j] labels p[j].

    Returns
    -------
    numpy.ndarray of shape (n, 1), float
        Entry i is t[j] for the centre p[j] closest (Euclidean) to X[i].
        When several centres are equally close, the highest-index one wins.
    """
    X = np.asarray(X)
    p = np.asarray(p)
    # (n, K) distance matrix: dists[i, j] = ||X[i] - p[j]||, computed with the
    # same sqrt(sum(square(.))) formula as Euclidean_Distance.
    dists = np.sqrt(np.sum(np.square(X[:, None, :] - p[None, :, :]), axis=2))
    # argmin returns the first minimum; scanning reversed columns and mapping
    # back yields the *last* minimum, i.e. ties go to the highest index.
    n_centres = dists.shape[1]
    nearest = n_centres - 1 - np.argmin(dists[:, ::-1], axis=1)
    sign = np.zeros((len(X), 1))
    sign[:, 0] = np.asarray(t)[nearest]
    return sign
def update(p, X, t, plot=True, pause=2):
    """Run one k-means step: relabel the points, then move each centre to its cluster mean.

    Parameters
    ----------
    p : array-like of shape (K, d)
        Current cluster centres; p[i] has label t[i].
    X : array-like of shape (n, d)
        Data points, one per row. The plot uses the first three columns.
    t : array-like of length K
        Label of each centre.
    plot : bool, optional
        If true, redraw the current clusters and the new centres in the
        current matplotlib figure.
    pause : float, optional
        Seconds passed to ``plt.pause`` after drawing.

    Returns
    -------
    numpy.ndarray of shape (K, d), float
        Updated centres. A centre whose cluster is empty keeps its old position.
    """
    X = np.asarray(X)
    try0 = mylabel(p, X, t)
    # Row indices of the points assigned to each centre.
    location_t = [np.where(try0 == label)[0] for label in t]
    p1 = np.array(p, dtype=float)
    for i, members in enumerate(location_t):
        # np.mean of an empty selection is NaN; leave an empty cluster's centre in place.
        if len(members) > 0:
            p1[i] = np.mean(X[members], axis=0)
    if plot:
        # Reuse the current figure instead of opening a new window each iteration.
        fig = plt.gcf()
        fig.clf()
        ax = fig.add_subplot(111, projection='3d')
        colours = ['grey', 'orange', 'pink']
        for i, members in enumerate(location_t):
            ax.scatter(X[members, 0], X[members, 1], X[members, 2],
                       c=colours[i % len(colours)])
        ax.scatter(p1[:, 0], p1[:, 1], p1[:, 2], c='b', marker='^')
        plt.pause(pause)
    return p1
def diff(p, p1):
    """Return how far each cluster centre moved between two iterations.

    Parameters
    ----------
    p, p1 : array-like of shape (K, d)
        Two sets of K centres; row i of each refers to the same cluster.

    Returns
    -------
    numpy.ndarray of shape (K,), float
        Euclidean distance between p[i] and p1[i] for every i.
    """
    delta = np.asarray(p, dtype=float) - np.asarray(p1, dtype=float)
    # Row-wise norms kept as floats: storing them in an integer array would
    # floor them and make a sub-unit convergence tolerance meaningless.
    return np.sqrt(np.sum(np.square(delta), axis=1))
def K_means(K, X, tol=0.0001, max_iter=300):
    """Cluster the rows of X into K groups with Lloyd's k-means algorithm.

    Parameters
    ----------
    K : int
        Number of clusters.
    X : numpy.ndarray of shape (n, d)
        Data points, one per row.
    tol : float, optional
        Stop once no centre moves farther than this between two iterations.
    max_iter : int, optional
        Maximum number of update steps, so the loop always terminates.

    Returns
    -------
    numpy.ndarray of shape (K, d)
        Final cluster centres.
    """
    # init's third value is a stand-in "previous" set of centres so that the
    # first convergence check fails and at least one update runs.
    p, t, old_p = init(K, X)
    for _ in range(max_iter):
        # Stop when the largest centre movement is within the tolerance.
        if max(diff(p, old_p)) <= tol:
            break
        old_p = p
        p = update(p, X, t)
    return p
def random_norm(mu1, mu2, mu3, sigma1, sigma2, sigma3, r, size=1000, burn_in=1000):
    """Draw 3-D samples by Gibbs sampling.

    (x, y) follow a bivariate normal distribution with means mu1, mu2,
    standard deviations sigma1, sigma2 and correlation r. z is drawn
    independently from a normal distribution with mean mu3 and standard
    deviation sigma3.

    Parameters
    ----------
    mu1, mu2, mu3 : float
        Means of x, y and z.
    sigma1, sigma2, sigma3 : float
        Standard deviations of x, y and z.
    r : float
        Correlation between x and y; must lie in [-1, 1].
    size : int, optional
        Number of samples returned.
    burn_in : int, optional
        Number of initial chain steps discarded before collecting samples.

    Returns
    -------
    list of [x, y, z] lists
        ``size`` samples, in chain order (consecutive samples are correlated).

    Raises
    ------
    ValueError
        If r is outside [-1, 1].
    """
    if not -1 <= r <= 1:
        raise ValueError("r must lie in [-1, 1]")
    res = []
    chain = [mu1, mu2, mu3]
    # Conditional standard deviations of a bivariate normal. random.normalvariate
    # takes a standard deviation, not a variance.
    cond_sd1 = sigma1 * (1 - r ** 2) ** 0.5
    cond_sd2 = sigma2 * (1 - r ** 2) ** 0.5
    # The first burn_in steps let the Markov chain approach its stationary distribution.
    for i in range(burn_in + size):
        # Sample Y from its conditional distribution given X = chain[0].
        x = chain[0]
        chain[1] = rand.normalvariate(
            mu=mu2 + r * sigma2 / sigma1 * (x - mu1),
            sigma=cond_sd2)
        # Sample X from its conditional distribution given Y = chain[1].
        y = chain[1]
        chain[0] = rand.normalvariate(
            mu=mu1 + r * sigma1 / sigma2 * (y - mu2),
            sigma=cond_sd1)
        # z is independent of (x, y).
        chain[2] = rand.normalvariate(mu=mu3, sigma=sigma3)
        if i >= burn_in:
            res.append(chain[:])
    return res
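# 3. 主函数（Main program）
# - 画出真实标签下的三类点（plot the three point clouds with their true labels）
# - 画出 k 均值聚类得到的三类点（plot the three clusters found by k-means）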
if __name__ == '__main__':
    K = 3
    random.seed(45)
    # Three synthetic 3-D point clouds centred at (0,0,0), (20,20,20) and (40,40,40).
    X = np.array(random_norm(0, 0, 0, 2, 1, 4, 0.5, size=1000))
    Y = np.array(random_norm(20, 20, 20, 1, 2, 4, 0.5, size=1000))
    Z = np.array(random_norm(40, 40, 40, 1, 2, 2, 0.5, size=1000))
    A = np.vstack((X, Y, Z))
    # Ground-truth labels 0/1/2 for the rows of A (for reference; the plots
    # below draw X, Y and Z directly).
    label = np.vstack((np.zeros((len(X), 1)),
                       np.ones((len(Y), 1)),
                       np.ones((len(Z), 1)) * 2))

    # K_means returns only the centres; recover the final assignment from them.
    p = K_means(K, A)
    final_labels = mylabel(p, A, np.arange(K))[:, 0]
    location_t = [np.where(final_labels == k)[0] for k in range(K)]

    # Figure: true clusters plus the k-means centres.
    ax = plt.figure().add_subplot(111, projection='3d')
    ax.scatter(X[:, 0], X[:, 1], X[:, 2], c='orange', marker='^')
    ax.scatter(Y[:, 0], Y[:, 1], Y[:, 2], c='pink', marker='^')
    ax.scatter(Z[:, 0], Z[:, 1], Z[:, 2], c='green', marker='^')
    ax.scatter(p[:, 0], p[:, 1], p[:, 2], c='b', marker='^')
    ax.set_xlabel('X', fontdict={'size': 15, 'color': 'green'})
    ax.set_ylabel('Y', fontdict={'size': 15, 'color': 'green'})
    ax.set_zlabel('Z', fontdict={'size': 15, 'color': 'green'})

    # Figure: clusters as assigned by k-means.
    ax = plt.figure().add_subplot(111, projection='3d')
    for members, colour in zip(location_t, ['grey', 'orange', 'pink']):
        ax.scatter(A[members, 0], A[members, 1], A[members, 2], c=colour)
    plt.show()