数据挖掘十大算法(十):K-means聚类算法

本文详细介绍了K-means算法的原理,包括基本算法步骤、初始中心点的选取策略以及误差平方和(SSE)的概念。还讨论了k值的确定方法,并提供了Python实现和sklearn库的应用示例。
摘要由CSDN通过智能技术生成

一,K-means算法原理

基本算法

K-means算法是最常用的一种聚类算法。算法的输入为一个样本集(或者称为点集),通过该算法可以将样本进行聚类,具有相似特征的样本聚为一类。
算法步骤:
step1:选定要聚类的类别数目k,同时选定初始中心点
step2:寻找组织,将每一个样本点分给k个中心点(根据距离)
step3:重新计算新的中心点
step4:判断中心点是否发生变化,若变化则重复,否则break

初始中心点的选取

初始中心点的选取,对聚类的结果影响较大。可以验证,不同初始中心点,会导致聚类的效果不同。如何选择初始中心点呢?一个原则是:

初始中心点之间的间距应该较大。因此,可以采取的策略是:

step1:计算所有样本点之间的距离,选择距离最大的一个点对(两个样本C1, C2)作为2个初始中心点,从样本点集中去掉这两个点。

step2:如果初始中心点个数达到k个,则终止。如果没有,在剩余的样本点中,选一个点C3,这个点优化的目标是:
在这里插入图片描述
这是一个双目标优化问题,可以约束其中一个,极值化另外一个,这样可以选择一个合适的C3点,作为第3个初始中心点。

如果要寻找第4个初始中心点,思路和寻找第3个初始中心点是相同的。

误差平方和(Sum of Squared Error)

误差平法和,SSE,用于评价聚类的结果的好坏,SSE的定义如下。
在这里插入图片描述
一般情况下,k越大,SSE越小。假设k=N=样本个数,那么每个点自成一类,那么每个类的中心点为这个类中的唯一一个点本身,那么SSE=0。

k值的确定

一般k不会很大,大概在2~10之间,因此可以作出这个范围内的SSE-k的曲线,再选择一个拐点,作为合适的k值。
在这里插入图片描述
可以看到,k=5之后,SSE下降的变得很缓慢了,因此最佳的k值为5。

二,基本原理的Python实现

# K-means Algorithm is a clustering algorithm
import numpy as np
import matplotlib.pyplot as plt
import random


def get_distance(p1, p2):
    diff = [x - y for x, y in zip(p1, p2)]
    distance = np.sqrt(sum(map(lambda x: x ** 2, diff)))
    return distance


# 计算多个点的中心
def calc_center_point(cluster):
    N = len(cluster)

    m = np.array(cluster).transpose().tolist()  # m的shape是(2, N)

    center_point = [sum(x) / N for x in m]  # 这里其实就是分别对x,y求平均
    return center_point


# 检查两个点是否有差别
def check_center_diff(center, new_center):
    n = len(center)
    for c, nc in zip(center, new_center):
        if c != nc:
            return False
    return True


# K-means算法的实现
def K_means(points, center_points):
    N = len(points)  # 样本个数
    n = len(points[0])  # 单个样本的维度
    k = len(center_points)  # k值大小

    tot = 0
    while True:  # 迭代
        temp_center_points = []  # 记录中心点

        clusters = []  # 记录聚类的结果
        for c in range(0, k):
            clusters.append([])  # 初始化

        # 针对每个点,寻找距离其最近的中心点(寻找组织)
        for i, data in enumerate(points):
            distances = []
            for center_point in center_points:
                distances.append(get_distance(data, center_point))

            index = distances.index(min(distances))  # 找到最小的距离的那个中心点的索引,
            clusters[index].append(data)  # 那么这个中心点代表的簇,里面增加一个样本(要理解这里)


        tot += 1
        print('Epoch:{} Clusters:{}'.format(tot, len(clusters)))
        k = len(clusters)
        colors = ['r.', 'g.', 'b.', 'k.', 'y.']  # 颜色和点的样式
        for i, cluster in enumerate(clusters):
            data = np.array(cluster)
            data_x = [x[0] for x in data]
            data_y = [x[1] for x in data]
            plt.subplot(2, 
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值