K-means 算法实现

K-means算法

最近用到了k-means算法,这里简单实现一下,和大家交流进步,介绍就不写了,需要的朋友看这里k-means维基百科

实现

这里是一个比较简单的实现,参考下图的伪代码。
有几点细节列出来

  1. 迭代过程中最初的中心点是随机生成的,所以即使相同的样本也不能保证每次的迭代结果都是一样的
  2. K的选择
    • 结合具体环境,样本分布等先验,知道是几个类别
    • 尝试不同的K值并比较确定最终的K值
    • 使用其他的聚类方法,如EM
  3. 聚类质量的评价
    使用类别最小外接圆形的直径同类间距离的比值
  4. K-means的局限性
    • 当类别分布不是圆形或者球形时,聚类效果相对较差
    • 对外点敏感

K-means 基本算法

# k-means clustering
import random
import time
import numpy as np
import matplotlib.pyplot as plt

def distance_l2(centroids, point):
    """
    计算当前样本同各个类别中心点的距离
    :param centroids: 样本中心点
    :param point: 当前样本
    :return: 当前样本所属类别
    """
    dist_dist = [None]*3
    for i, centroid in enumerate(centroids):
        dist = np.sqrt((centroid[0]-point[0])**2 + (centroid[1]-point[1])**2)
        dist_dist[i] = [dist, i]
    dist_dist.sort(key=lambda elem: elem[0])
    return dist_dist[0][1]


if __name__ == "__main__":
    random.seed()

    # 生成数据
    class_n = 3         # 类别
    class_i_size = 50   # 每个类别的样本个数
    centers = [[40, 60], [100, 80], [75, 50]]   # 生成数据时的样本中心点

    samples = [[]]*class_n
    icon = ['b.', 'r+', 'gx']
    fig = plt.figure()
    points_prop = {}
    for i in range(class_n):
        x, y = centers[i]
        for j in range(class_i_size):
            generateddata = (x + random.normalvariate(0, 8), y + random.normalvariate(0, 8))
            if not generateddata in samples[i]:  # 去掉重复数据
                samples[i].append(generateddata)
                plt.plot(generateddata[0], generateddata[1], icon[i])
    plt.show()

    max_iters = 10  # 最大迭代次数
    samples_shuttle = samples[0] + samples[1] + samples[2]  # 样本
    size = len(samples_shuttle)     # 样本大小

	# 聚类
    # 随机选择中心点
    centroids = random.sample(samples_shuttle, class_n)
    for i in range(max_iters):
        print('iter:\t', i)
        centroid_samples = [[], [], []]
        for k in range(class_n):
            plt.scatter(centroids[k][0], centroids[k][1], c=icon[k][0], s=100)  # 绘制当前中心点

        # 计算距离并归类
        for j in range(size):
            index = distance_l2(centroids, samples_shuttle[j])
            centroid_samples[index].append(samples_shuttle[j])
            plt.plot(samples_shuttle[j][0], samples_shuttle[j][1], icon[index])

        # 如果其中一类的样本个数太少,则重新随机生成中心点
        if True in (len(centroid_samples[0]) < 0.1*size,
                    len(centroid_samples[1]) < 0.1*size,
                    len(centroid_samples[2]) < 0.1*size):
            centroids = random.sample(samples_shuttle, class_n)
            i -= 1
            print('continue')
            continue

        # 根据重新聚类结果计算中心点
        static_flag = 0
        for m in range(class_n):
            aver_x, aver_y = 0., 0.
            size_m = len(centroid_samples[m])
            for p in centroid_samples[m]:
                aver_x += p[0]
                aver_y += p[1]
            centroid_temp = [aver_x/size_m, aver_y/size_m]
            if centroids[m] == centroid_temp:
                static_flag += 1
            centroids[m] = centroid_temp
        time.sleep(1)
        plt.show()

        # 当聚类中心点不再变化时,聚类收敛,直接退出
        if static_flag == class_n:
            print("converge")
            break
        if i == max_iters-1:
            # 一直迭代至最大次数退出
            print('exit at max_iters')
    print('end')

结果

这是代码的一次结果,较大的红绿蓝圆点是当前分布的中心点,第一张图是生成数据是各类别的分布情况,最后一张是聚类的结果,中间是聚类过程。
在这里插入图片描述

参考

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值