【数据挖掘】数据挖掘经典算法之The K-means algorithm

1 篇文章 0 订阅
1 篇文章 0 订阅

聚类

  聚类是一种无监督的学习,它将相似的对象归到同一个簇中。聚类的方法几乎可以应用于所有对象,簇内的对象越相似,聚类的效果越好。K-means(K-均值聚类)算法使一种聚类算法。之所以称之为K-均值使因为它可以发现k个不同的簇,且每个簇的中心采用簇中所含值的均值计算而成。

K-均值聚类算法

  K-均值是发现给定数据集的k个簇的算法。簇个数k是用户给定的,每一个簇通过其聚类中心,即簇中所有点的中心来描述。
  K-均值算法的工作流程为——首先,随机确定k个初始点作为质心。然后将数据集中的每个点分配到一个簇中,具体来讲,为每个点找距其最近的聚类中心,将其分配给该聚类中心所对应的簇。这一步完成之后,每个簇的聚类中心更新为该簇所有点的平均值。
  上述过程的伪代码表示如下:

	创建k个点作为初始聚类中心(一般从样本中随机选择,我的代码里是硬编码为一个数组)
	当任意一个点的簇分配结果发生改变时
		对数据集中的每个数据点
			对每个聚类中心
				计算聚类中心与数据点之间的距离
			将数据点分配到距其最近的簇
		对每一个簇,计算簇中所有点的均值并将均值作为聚类中心

K-means的python实现:

# -*- coding: utf-8 -*-

from scipy.io import loadmat
from numpy import *
import matplotlib.pyplot as plt


def find_closest_centroids(centroids,data): 
    '''
    找出每组数据最接近的聚类中心
    Args:
        centroids:聚类中心
        data:数据集
    Returns:
        idx:数据集所属聚类中心的下标
    '''
    data_size = len(data)   
    idx = zeros((data_size,1))
    K = len(centroids)  
    
    for i in range(data_size):
        temp = sum(power(data[i] - centroids,2),1) # 求每组数据距离的平方,求距离省略了开方过程,大小关系不受影响
        idx[i] = temp.argmin() # 数据集所属聚类中心的下标

    return idx

def compute_centroids(data,idx,k):
    '''
    重新计算聚类中心
    Args:
        data:数据集
        idx:数据集所属聚类中心的下标
        k:聚类数量
    Returns:
        centroids:更新后的聚类中心
    '''
    data_size = len(data)
    cluster_sum = zeros((k,2)) # 一个簇中所有数据之和,用于求平均值
    cluster_data_size = zeros((k,1)) # 一个簇包含的数据长度,用于求平均值
    
    for i in range(data_size): # 遍历每一组数据,计算cluster_sum和cluster_data_size
        cluster_sum[int(idx[i])] = cluster_sum[int(idx[i])] + data[i] 
        cluster_data_size[int(idx[i])] = cluster_data_size[int(idx[i])] + 1
    
    centroids = cluster_sum / cluster_data_size
    
    return centroids

def figure_cluster(data,centroids,idx):
    '''
    在坐标轴中根据聚类画出数据样本,描出聚类中心
    
    Args:
        data:数据集
        centroids:聚类中心
        idx:数据集所属聚类中心的下标
    Returns:
        centroids:更新后的聚类中心
    '''
    data_size = len(data)
    colors = ['r','b','y'] 
    
    for i in range(data_size):
        index = int(idx[i])
        plt.scatter(data[i,0],data[i,1],marker='o',c=colors[index],s=15)
    
    print(centroids[:,0],centroids[:,1])
    plt.scatter(centroids[:,0],centroids[:,1],marker='x',color='000')


data_mat = loadmat("ex7data2.mat") # mat数据集
data = data_mat['X']

K = 3

initial_centroids = array([[3,3],[6,2],[8,5]]) # 硬编码初始聚类中心

centroids = initial_centroids

'''
迭代,当更新的聚类中心不再变化时结束
'''
while True:
    old_centroids = centroids
    idx = find_closest_centroids(centroids,data)
    centroids = compute_centroids(data,idx,K)

    if (old_centroids == centroids).all():
        break

figure_cluster(data,centroids,idx)

PS:此处所用的数据集是吴恩达机器学习课程中的数据集,是mat类型的数据,可自行更改为自己的数据集

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

WGeeker

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值