代码:
import numpy as np
import scipy.io as sio
import matplotlib.pyplot as plt
# 获取数据
data1=sio.loadmat('ex7data2.mat')
print(data1.keys())
X=data1['X']#(300,2)
# 初始数据可视化
plt.scatter(X[:,0],X[:,-1])
plt.show()
# 获取每个样本所属类别
def find_sentroids(X,centros):
idx=[]
for i in range(len(X)):
#X[i]是一个(2,)的一维数组,centros是(k,2),然后有个扩充广播机制最后相减是一个(k,2)维度
dist=np.linalg.norm((X[i]-centros),axis=1)#(k,)
id_i=np.argmin(dist)
idx.append(id_i)
return np.array(idx)
centros=np.array([[3,3],[6,2],[8,5]])
idx=find_sentroids(X,centros)
# 计算聚类中心点
def compute_centros(X,idx,k):
centros=[]