一、KMeans算法实现代码
class KMeans:
    """Plain K-Means clustering (Lloyd's algorithm) implemented with NumPy.

    Attributes:
        data: (num_examples, num_features) array of samples.
        k_clustres: number of clusters (attribute name kept for backward
            compatibility with existing callers).
    """

    def __init__(self, data, k_clustres):
        self.data = data
        self.k_clustres = k_clustres

    def train(self, max_iterations):
        """Run up to ``max_iterations`` rounds of assign/update.

        Args:
            max_iterations: maximum number of Lloyd iterations.

        Returns:
            (centers, closest_center_index): the final (k, num_features)
            centroids and the (num_examples, 1) cluster id of each sample.
        """
        # Randomly pick k distinct samples as the initial centroids.
        centers = self.center_init(self.data, self.k_clustres)
        closest_center_index = np.empty((self.data.shape[0], 1))
        for _ in range(max_iterations):
            # Assign every sample to its nearest *current* centroid.
            closest_center_index = self.center_find_closest(self.data, centers)
            # BUG FIX: the original stored the updated centroids in a
            # different variable (`centers`) while the assignment step kept
            # reading the initial `center`, so every iteration after the
            # first was a no-op. Feed the updated centroids back in.
            centers = self.center_update(self.data, closest_center_index,
                                         self.k_clustres)
        return centers, closest_center_index

    def center_init(self, data, k):
        """Choose k distinct random rows of ``data`` as initial centroids."""
        indices = np.arange(data.shape[0])
        # Shuffle the index vector and take the first k -> k distinct samples.
        np.random.shuffle(indices)
        return data[indices[:k], :]

    def center_find_closest(self, data, k_clustres):
        """Return the (num_examples, 1) index of the nearest centroid per sample.

        ``k_clustres`` here is the (k, num_features) centroid array, matching
        the original parameter name.
        """
        examples = data.shape[0]  # use the parameter, not self.data
        center_num = k_clustres.shape[0]
        closest_center_index = np.zeros((examples, 1))
        for example_index in range(examples):
            distance = np.zeros((center_num, 1))
            for center_index in range(center_num):
                diff = data[example_index, :] - k_clustres[center_index, :]
                # Squared Euclidean distance; the sqrt is monotone so it is
                # unnecessary for an argmin.
                distance[center_index] = np.sum(diff ** 2)
            closest_center_index[example_index] = np.argmin(distance)
        return closest_center_index

    def center_update(self, data, closest_center_index, k_clustres):
        """Recompute each centroid as the mean of its assigned samples."""
        num_features = data.shape[1]
        centerid = np.zeros((k_clustres, num_features))
        for center_id in range(k_clustres):
            closest_ids = (closest_center_index == center_id).flatten()
            # NOTE(review): if a cluster loses all its points, np.mean of an
            # empty slice yields NaN (with a RuntimeWarning) — same behavior
            # as the original; a robust version would re-seed empty clusters.
            centerid[center_id] = np.mean(data[closest_ids, :], axis=0)
        return centerid
二、鸢尾花数据集测试
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
# Load the iris dataset and unpack the arrays used by the plots below.
data = load_iris()
X, y = data['data'], data['target']
class_name = data['target_names']
feature_name = data['feature_names']
class_name
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
feature_name
['sepal length (cm)',
'sepal width (cm)',
'petal length (cm)',
'petal width (cm)']
# Append the label column so each row carries its class id in column 4.
data_ = np.hstack((X, y.reshape(-1, 1)))
data_[:4,:]
array([[5.1, 3.5, 1.4, 0.2, 0. ],
[4.9, 3. , 1.4, 0.2, 0. ],
[4.7, 3.2, 1.3, 0.2, 0. ],
[4.6, 3.1, 1.5, 0.2, 0. ]])
三、未知label的数据分布
# Feature columns used for every 2-D scatter plot in this script.
x_axis = 0
y_axis = 1

plt.figure(figsize=(12, 5))

# Left panel: colour the samples by their true class.
plt.subplot(1, 2, 1)
for label_index in range(len(class_name)):
    mask = data_[:, 4] == label_index
    plt.scatter(data_[mask, x_axis], data_[mask, y_axis])
plt.title("label know")
plt.legend(labels=class_name)

# Right panel: the same samples with the labels withheld.
plt.subplot(1, 2, 2)
plt.scatter(data_[:, x_axis], data_[:, y_axis])
plt.title("label unknow")
plt.show()
四、已知label和聚类划分的类别对比
# Train K-Means on the same two feature columns that are plotted.
x_train = data_[:, x_axis:y_axis + 1]

# Hyper-parameters: cluster count and iteration budget.
k_clusteri = 3
max_iteration = 50
k_means = KMeans(x_train, k_clusteri)
centers, closet_center_ids = k_means.train(max_iteration)

plt.figure(figsize=(12, 5))

# Left panel: ground-truth classes for reference.
plt.subplot(1, 2, 1)
for label_index in range(len(class_name)):
    true_mask = data_[:, 4] == label_index
    plt.scatter(data_[true_mask, x_axis], data_[true_mask, y_axis])
plt.title("label know")
plt.legend(labels=class_name)

# Right panel: the K-Means partition, then the learned centroids.
# (Two separate loops keep the artist order — clusters first, then the
# black 'x' markers — so the legend maps to the cluster scatters.)
plt.subplot(1, 2, 2)
for center_id, centroid in enumerate(centers):
    cluster_mask = (closet_center_ids == center_id).flatten()
    plt.scatter(data_[cluster_mask, x_axis], data_[cluster_mask, y_axis],
                label=center_id)
for center_id, centroid in enumerate(centers):
    plt.scatter(centroid[0], centroid[1], c='black', marker='x')
plt.legend(labels=class_name)
plt.title("label kmeans")
plt.show()
五、总结
从上图中可以明显地看出 KMeans 的缺点:每次聚类产生的划分都可能不一样。
这表明它受随机初始化的簇中心影响较大;而且像左图这种样本分布,仅靠简单的欧氏距离也很难把它们细分开。