KMeans算法(使用numpy)——python

博主近期利用numpy库对KMeans聚类算法进行了优化,详细介绍了改进过程。
摘要由CSDN通过智能技术生成

最近学习了numpy库,对之前的KMeans算法进行了一个改进。

import random
import copy
import matplotlib.pyplot as plt
import numpy as np
import time
# 该写法只在计算使用numpy,其余使用列表
class KMeans():
    def __init__(self,data,k=1):
        # 可以传入ndarry类型
        self.__data = data  # 存放输入点数据
        self.__k = k  # 中心点个数
        # 可以改为ndarry类型
        self.__centerPoint = []  # 中心点
        self.__result = []  # 对输入点进行分类的结果
        for i in range(k):
            self.__result.append([])   #[[] [] [] [] [] ]

    # 随机取k个中心点
    def randomCenterPoint(self):
        for i in range(self.__k):
            # 随机出k个下标
            index = random.randint(0,len(self.__data)-1)
            # 判断取出的数是否在中心点中  防止取重复数  使中心点重复
            # 不使用ndarry类型,因为tnp.append添加数据特别耗时间
            if self.__data[index] not in self.__centerPoint:
                self.__centerPoint.append(self.__data[index])
        pass

    # 把数据进行分类 计算各点到中心点的距离
    # data为点数据
    # center为中心点
    def calCenterPointDistance(self, data, center):
        centerDistance = []  # 用来记录每个点到k个中心点的距离[[k个数][k个数][k个数][k个数]...]长度为len(data)
        # 计算并存储各点到中心点距离
        center = np.array(center)
        data = np.array(data)
        for temp in data:
            centerDistance.append((np.sum((center-temp)**2,axis = 1)**0.5))
            pass

        self.__result = []
        # 因为对result进行的使append操作 及时清空 否则数据会堆积 越来越多
        # 对result清空操作
        for i in range(self.__k):
            self.__result.append([])
        m = 0
        # 根据各点到中心点距离  把数据点进行分类
        for temp in centerDistance:
            index = np.argmin(temp).tolist()
            self.__result[index].append(copy.deepcopy(data[m]))
            m += 1

    # 计算生成新的中心点
    def newCenterPoint(self, result):
        newCenterPoint = []  # 存放新的中心点
        # 转置矩阵  把各点x,y ,z...放在同一数组 方便计算
        for temp in result:
            temp = np.array(temp)
            temps = (np.sum(temp,axis=0)/len(temp)).tolist()
            newCenterPoint.append(copy.deepcopy(temps))
        return newCenterPoint

    #  计算新旧中心点之间的距离
    #  old代表原来的中心点列表  new新生成的中心点列表
    def calCenterPointToCenterPointDistance(self, old, new):
        old = np.array(old)
        new = np.array(new)
        res = np.sum((np.sum((new - old)**2,axis=1)**0.5))/len(old)
        return res  
        pass

    # 执行函数
    def fit(self, threshold):
        self.randomCenterPoint() # 随机中心点
        self.calCenterPointDistance(self.__data, self.__centerPoint)  # 把数据根据中心点分类
        newCenterPoint = self.newCenterPoint(self.__result)  # 生成新的中心点
        oldCenterPoint = self.__centerPoint  # 旧的中心点
        # 程序结束的条件
        while self.calCenterPointToCenterPointDistance(oldCenterPoint,newCenterPoint) > threshold:
            self.calCenterPointDistance(self.__data, newCenterPoint)  # 对data点数据进行新的分类
            oldCenterPoint = newCenterPoint  # 覆盖旧点
            newCenterPoint = self.newCenterPoint(self.__result)  # 生成新点

        self.__centerPoint = newCenterPoint
        return newCenterPoint, self.__result

if __name__ == "__main__":
    #  生成x个随机点
    a = time.time()
    data = [[random.randint(1, 100), random.randint(1, 100)]for i in range(10000)]
    km = KMeans(data, 6)
    center,result = km.fit(0.0001)

    # 数据可视化 可清楚看见点的分布情况
    plt.plot()
    plt.title("KMeans Classification")
    i = 0
    tempx = []
    tempy = []
    color = []
    for temp in result:
        temps = [[temp[x][i] for x in range(len(temp))] for i in range(len(temp[0]))]
        color += [i] * len(temps[0])
        tempx += temps[0]
        tempy += temps[1]
        i += 2
        pass
    plt.scatter(tempx, tempy, c=color, s=30)
    plt.show()
    b = time.time()
    print(b-a)
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值