Kmeans简单实现

一、KMeans算法实现代码

class KMeans:
    def __init__(self,data,k_clustres):
        self.data = data
        self.k_clustres = k_clustres
    
    def train(self,max_iterations):
        #随机选代表
        center = self.center_init(self.data,self.k_clustres)
        #计算最小距离
        examples = self.data.shape[0]
        #距离最小的index
        closest_center_index = np.empty((examples,1))
        #循环更新最小距离
        for i in range(max_iterations):
            #每一个样本点到K个中心点的距离,找到最近的
            closest_center_index = self.center_find_closest(self.data,center)
            #更新中心点
            centers = self.center_update(self.data,closest_center_index,self.k_clustres)
        return centers,closest_center_index
    
    #选代表
    def center_init(self,data,k):
        examples = data.shape[0]
        examples = np.arange(examples)
        #打乱所有的样本
        np.random.shuffle(examples)
        random_index = examples
        
        #选出初始代表
        center = data[random_index[:k],:]
        return center
    
    #算距离
    def center_find_closest(self,data,k_clustres):
        examples = self.data.shape[0]
        center_num = k_clustres.shape[0]
        #初始化-->一个点到各中心点的距离向量
        closest_center_index = np.zeros((examples,1))
        #遍历每一个样本
        for example_index in range(examples):
            #每一个样本点与中心距离最小的那个中心的向量
            distance = np.zeros((center_num,1))
            #遍历每一个中心代表
            for center_index in range(center_num):
                distance_num = data[example_index,:] - k_clustres[center_index,:]
                #当前样本点与当前中心的距离
                distance[center_index] = np.sum(distance_num**2)
            #遍历完各个中心点,把离簇中心点最近的那个
            #样本点的距离记录到,与样本点一样index位置
            closest_center_index[example_index] = np.argmin(distance)
        return closest_center_index
    
    #更新代表
    def center_update(self,data,closest_center_index,k_clustres):
        #样本点的特征数
        num_features = data.shape[1]
        #拼接
        centerid = np.zeros((k_clustres,num_features))
        #遍历每一个簇
        for center_id in range(k_clustres):
            closest_ids = closest_center_index == center_id
            #选均值当新的代表
            centerid[center_id] = np.mean(data[closest_ids.flatten(),:],axis=0)
        return centerid

二、鸢尾花数据集测试

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
data = load_iris()
X = data['data']
class_name = data['target_names']
feature_name = data['feature_names']
y = data['target']
class_name
array(['setosa', 'versicolor', 'virginica'], dtype='<U10')
feature_name
['sepal length (cm)',
 'sepal width (cm)',
 'petal length (cm)',
 'petal width (cm)']
data_ = np.concatenate((X,y.reshape(-1,1)),axis=1)
data_[:4,:]
array([[5.1, 3.5, 1.4, 0.2, 0. ],
       [4.9, 3. , 1.4, 0.2, 0. ],
       [4.7, 3.2, 1.3, 0.2, 0. ],
       [4.6, 3.1, 1.5, 0.2, 0. ]])

三、未知label的数据分布


x_axis = 0
y_axis = 1
plt.figure(figsize=(12,5))
plt.subplot(1,2,1)
for label_index in range(len(class_name)):
    plt.scatter(data_[:,x_axis][data_[:,4]==label_index],data_[:,y_axis][data_[:,4]==label_index])
plt.title("label know")
plt.legend(labels=class_name)

plt.subplot(1,2,2)
plt.scatter(data_[:,x_axis][:],data_[:,y_axis][:])
plt.title("label unknow")
plt.show()

在这里插入图片描述

四、已知lable和聚类划分的类别对比

x_train = data_[:,x_axis:y_axis+1]
#指定参数
k_clusteri = 3
max_iteration=50
k_means = KMeans(x_train,k_clusteri)
centers,closet_center_ids = k_means.train(max_iteration)
#对比聚类的结果
plt.figure(figsize=(12,5))
plt.subplot(1,2,1)
for label_index in range(len(class_name)):
    plt.scatter(data_[:,x_axis][data_[:,4]==label_index],data_[:,y_axis][data_[:,4]==label_index])
plt.title("label know")
plt.legend(labels=class_name)

plt.subplot(1,2,2)
for center_id,centroid in enumerate(centers):
    current_examples_index = (closet_center_ids == center_id).flatten()
    plt.scatter(data_[:,x_axis][current_examples_index],data_[:,y_axis][current_examples_index],label=center_id)

for center_id,centroid in enumerate(centers):
    plt.scatter(centroid[0],centroid[1],c='black',marker='x')
plt.legend(labels=class_name)
plt.title("label kmeans")
plt.show()

在这里插入图片描述

五、总结

从上图中可以明显的看出,kmeans的缺点了。在每次的聚类产生的划分都是不一样的

这表明它受随机初始化的初始簇中心影响较大,而且样本分布如左上图的这种,用简单的距离就很难把他们细分开了

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

学AI不秃头

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值