numpy实现k-means算法

numpy实现k-means算法


看numpy的中文文档学了一些
尝试了一下他的k-means代码
有些地方在本地运行有小问题
改了一下,然后加了详细注释帮助理解
原代码的地址是: NumPy 中文文档

import matplotlib.pyplot as plt
import numpy as np
import random
import math
import re

def euclDistance(v1, v2):
    #计算两点间的距离
    return math.sqrt(sum(np.power(v2 - v1, 2)))

def initCentroids(dataSet, k):#dataSet-数据点数组 k-设置的质心数
    #初始化质心 
    numSamples, dim = dataSet.shape#numSample-数据点个数 dim-数据点维数 
    #shape返回一个关于数组长宽的数组
    centroids = np.zeros((k, dim))#centroids-存放质心的数组
    for i in range(k):
        index = int(random.uniform(0, numSamples))#index-在零到数据点个数间的随机数
        centroids[i, :] = dataSet[index, :]
        #将随机质心存储入centroids
        """----------------------------------------疑问:为什么random出的随机数总是数据点的其中一个"""
    return centroids

def kmeans(dataSet, k):
    #k-means算法的核心函数
    numSamples = dataSet.shape[0]#数据点个数为数据点数组的行数
    clusterAssment = np.mat(np.zeros((numSamples, 2)))#clusterAssment-存储数据点集的矩阵
    #1.zeros生成一个行数numSamples,列数2的数组
    #2.mat将1生成的数组转换成矩阵
    clusterChanged = True#clusterChanged-表示是否需要重新分组的布尔值判定量
    
    centroids = initCentroids(dataSet, k)#初始化质心
    
    while clusterChanged:#需要重新分组时
        clusterChanged = False#重置判定量为假
        for i in range(numSamples):#遍历所有数据点
            minDist = 100000.0#minDist-最小的数据点与质心的距离
            minIndex = 0#minIndex-最小的链接地址
            for j in range(k):
                #计算每个数据点到哪个质心的距离最小,及记录是哪一个质心
                distance = euclDistance(centroids[j, :], dataSet[i, :])#distance-暂时存放数据点到质心的距离
                if distance < minDist:
                    minDist = distance
                    minIndex = j
            if clusterAssment[i, 0] != minIndex:#当该数据点所隶属的质心与最小链接地址不同时更新点中的数据
                clusterChanged = True#重置判定量为真
                clusterAssment[i, :] = minIndex, minDist**2#该数据点的第二列变为一个数组,又隶属的质心链接与最短距离的平方组成
        for j in range(k):#由新的隶属关系中更新质心位置
            pointsInCluster = dataSet[np.nonzero(clusterAssment[:, 0].A == j)[0]]#形成一个包含所有数据点最小距离的矩阵?
            """----------------------------------------------------疑问:具体过程不明"""
            centroids[j, :] = np.mean(pointsInCluster, axis = 0)
            #np.mean()-求取矩阵均值 axis=0-对矩阵每列求均值
    print("分类完成")
    return centroids,clusterAssment

def showCluster(dataSet, k, centroids, clusterAssment):
	#数据可视化
    numSamples, dim = dataSet.shape
    if dim != 2:
        print("无法处理非二维数据")
        return 1
    mark = ['or', 'ob', 'og', 'ok', '^r', '+r', 'sr', 'dr', '<r', 'pr']
    if k > len(mark):
        print("k的值太大了,表现不出来了")
        return 1
    for i in range(numSamples):#画出所有数据点
        markIndex = int(clusterAssment[i, 0])
        plt.plot(dataSet[i, 0], dataSet[i, 1], mark[markIndex])
    mark = ['Dr', 'Db', 'Dg', 'Dk', '^b', '+b', 'sb', 'db', '<b', 'pb']
    for i in range(k):
        plt.plot(centroids[i, 0], centroids[i, 1], markersize = 12)
            
    plt.show()
    
print("加载数据")
dataSet = []
file = open("testSet.txt")
for line in file.readlines():
    #从文件取出每行 为一字符串
    lineArr = line.strip().split('    ')
    #删除首尾空格 将字符串从中间的空格处分开
    s1 = re.search('-?\d+.\d+', lineArr[0], re.M|re.I).group()
    s2 = re.search('-?\d+.\d+', lineArr[1], re.M|re.I).group()
    #使用正则取出字符串中的浮点数字符串
    dataSet.append([float(s1), float(s2)])#将数据点存入数组
dataSet = np.mat(dataSet)#将数据点数组转化为矩阵
k = 4
centroids, clusterAssment = kmeans(dataSet, k)
showCluster(dataSet, k, centroids, clusterAssment)

虽然大致是明白了,但是其中有两处代码还是不知道实现过程 ╯_╰
如果有大佬知道,请帮忙解释一下(´v`)

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值