机器学习(5)--K-means聚类(Clustering)算法

 K-means算法简述:
 1、 K-means算法是聚类(Clustering)中的经典算法,同时,也是数据挖掘的经典算法之一
 2、 该算法主要参数K,即在一些样本数据数,我们 不知道每个样本是什么类,但是我们知道全部的样本分为几类或是我们想把样本分为几类,这里的几类就是K
 3、本例基本步骤
    3.1 选取前K个样本,每个样本分为一类,并设置这个K样本的坐标为中心点
    3.2 计算所有每个样本与中心点的距离,这样得到每个中心点有哪些样本
    3.3 计算每个中心点所有样本的坐标的平均值,做为新的坐标
    3.4 循环3.2步聚,直至后一次的循环每个中心点所包含的样本不再发生变化时,退出循环


本文将通过matplotlib来显示中心点的变化与每个类的变化,如果你未安装matplotlib,可以屏蔽这几句相关的内容

程序运行时会跳出matplotlib窗体,并中断程序,关闭窗体后程序会继续执行。

# -*- coding:utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np

#下面两行,解决matplotlib中无法显示中文的问题
from pylab import *  
mpl.rcParams['font.sans-serif'] = ['SimHei']  

#数据定义,每一行为一个坐标,你也可以自行修改
data='''
1,1
2,1
4,5
6,6
5,4
3,3
2,2
3,2
5,6
1,3
3,1
6,5
''';

#整理数据
#[0,0]的意思在每个坐标后加两个值,每个样本数据均为4个值,[x,y,0,0]
#第1个为分类,
#第2个作用循环进行分类时为与前次分类比较,看是否发生变化,变化则为1,当所有样本均未变化时结束分类
data=[x.split(',')+[0,0] for x in data.split('\n')]
data=list(filter(lambda x: len(x)==4,data))
data=np.array(data).astype(np.float)
#print(data)
#取得中心点,选取前K个样本,每个样本分为一类

k=2  #在本中因为在matplotlib按类显示不同的点,只设了四个显示,分类数别太多了
centroids=data.copy()[:k,:-2]#-2表示只取样本的坐标


def draw(centroids,data,title):
    plt.axis([round((np.min(data,axis=0)-1)[0])
              ,round((np.max(data,axis=0)+1)[0])
              ,round((np.min(data,axis=0)-1)[1])
              ,round((np.max(data,axis=0)+1)[1])]) # 用于定义X,Y轴的范围
    plt.title(title)
    for index,center in enumerate(centroids):
        colorStr='rgby'[index:index+1] #在本中因为在matplotlib按类显示不同的点,只设了四个显示,分类数别太多了
        centerData=np.array(list(filter(lambda x:x[-2]==index+1 ,data)))
        if len(centerData)>0 :plt.scatter(centerData[:,0],centerData[:,1],c=colorStr) 
        plt.scatter(center[0],center[1],c=colorStr,marker='x') 
    plt.show()


runtimes=0
changePointLength=-1
while changePointLength!=0:
    runtimes+=1
    draw(centroids,data,'第 %d 次'%runtimes + ('  首次仅显示中心点,因为所有点的未分类' if runtimes==1 else ''))
    #3.2 计算所有每个样本与中心点的距离,这样得到每个中心点有哪些样本
    for dataItem in data:
        distances=np.sqrt(((centroids-dataItem[:-2])**2).sum(axis=1))#计算每个点与每个中心点的距离
        minDisType=np.argmin(distances)+1 #取得取小的距离的分类号
        if dataItem[-2]==minDisType:
            dataItem[-1]=0 #如果分类结果未发生变化
        else :
            dataItem[-1]=1 #如果分类结果发生变化
            dataItem[-2]=minDisType

    print(data)

    #3.3 计算每个中心点所有样本的坐标的平均值,做为新的坐标
    for index,center in enumerate(centroids):
        centerData=np.array(list(filter(lambda x:x[-2]==index+1 ,data)))[:,:-2] #得到每中心点包含哪些点
        centerData=centerData.mean(axis=0)
        center[0]=centerData[0]
        center[1]=centerData[1]
    #print(data)
    changePointLength=len(list(filter(lambda x:x[-1]==1 ,data))) #看有几个点的分类发生变化,如果为零,则退出循环


  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值