简单实现K_means聚类分析

聚类分析

简单的理解下聚类分析,我们现在平面上生成一些随机点(x,y),每个点都有不同的位置,现在设定一个需求,就是将这些点分成K类,该怎么去分?

K_means

聚类分析有很多种方法,这里我们使用K_means方法简单实现,K_means的原理如下所示:

  1. 设置类的个数,即K的值,必须提前知道我们要分成几类。

  2. 在随即点位的内部,根据算法随机生成K类的中心点,当然,这个中心点位并不准确,需要多次迭代。

  3. 在平面范围内,依次判断每个点距离最近的中心点,并将该点设置为该中心点的类中。依次判断平面内的所有点,直至分完为止。

  4. 根据分好的K类,计算出每个类的中心点,这个计算出来的中心点,可能和之前随机生成的中心点不同,然后更新中心点,重复2的操作。

  5. 设置迭代次数,或者中心点位不再变化,误差降低到理想值,停止迭代。

举个栗子

我先用numpy随机生成一些点位坐标,并保存为txt文件

import numpy as np

with open('point.txt','a') as fp:
    for i in range(400):
        fp.write(str(np.random.rand()*100))
        fp.write(',')
        fp.write(str(np.random.rand() * 100))
        fp.write('\n')
    fp.close()

print('gen points successful!!')

可视化如下图:

在这里插入图片描述

聚类分析的具体代码如下:

import re
import numpy as np
import matplotlib.pyplot as plt
import time

class Start_Point():
    def __init__(self):
        self.x = []
        self.y = []

class Point(Start_Point):
    def __init__(self,path):
        super(Point, self).__init__()
        self.path = path
        self.point_list = []
        self.read_points()

    def read_points(self):
        pat = re.compile(',')
        with open(self.path,'r') as fp:
            for each in fp.readlines():
                point = (float(pat.split(each.strip())[0]),float(pat.split(each.strip())[1]))
                self.point_list.append(point)
                self.x.append(point[0])
                self.y.append(point[1])

class Mean_point(Start_Point):
    def __init__(self,K):
        super(Mean_point, self).__init__()
        self.classes = K
        self.mean_points = []
        self.gen_mean_points()

    def gen_mean_points(self):
        for i in range(self.classes):
            a,b = np.random.rand()*100,np.random.rand()*100
            self.x.append(a)
            self.y.append(b)
            self.mean_points.append((a,b))

    def calcuate(self,point):
        num = 0
        start = time.time()
        while True:
            # 一次分类
            temp_data_x = [[] for i in range(self.classes)]
            temp_data_y = [[] for i in range(self.classes)]

            for i in range(len(point.point_list)):
                # 自定义一个最大的loss,并且在每个点比较四次后,初始化
                loss = 450
                for j in range(self.classes):
                    temp = ((self.x[j] - point.x[i]) ** 2 + (self.y[j] - point.y[i]) ** 2) ** 0.5
                    if loss > temp:
                        loss = temp
                        index = j
                temp_data_x[index].append(point.x[i])
                temp_data_y[index].append(point.y[i])

            # 一次求散点的means值,求散点的均值可以优化一下算法,这个算法比较简单
            # 改变成迭代方式
            # new_means_x = []
            # new_means_y = []

            for i in range(len(temp_data_y)):
                means_x = 0
                means_y = 0
                for j in range(len(temp_data_y[i])):
                    means_x += temp_data_x[i][j]
                    means_y += temp_data_y[i][j]
                if len(temp_data_y[i]) != 0:
                    means_x = means_x / len(temp_data_x[i])
                    means_y = means_y / len(temp_data_y[i])
                # 更新means_point
                if self.x[i] != means_x and self.y[i] != means_y:
                    self.x[i] = means_x
                    self.y[i] = means_y
                else:
                    break

            # 设置迭代次数
            num += 1
            if num > 200:
                break

            # 绘制散点图
            if num % 20 == 0 :
                plt.figure(figsize=(6, 6))  # 图片像素大小
                # plt.scatter(data.X, data.Y, color="red")  # 散点图绘制
                plt.scatter(temp_data_x[0], temp_data_y[0], color='red')
                plt.scatter(temp_data_x[1], temp_data_y[1], color='yellow')
                plt.scatter(temp_data_x[2], temp_data_y[2], color='black')
                plt.scatter(temp_data_x[3], temp_data_y[3], color='pink')
                plt.scatter(self.x, self.y, color='blue')
                plt.legend()
                plt.show()  # 显示图片


        end = time.time()
        print('use time: {}'.format(end-start))


if __name__ == '__main__':
    mean_point = Mean_point(K=4)   #改变K的值,要在plt中修改显示
    mean_point.calcuate(point=Point(path='point.txt'))


迭代的效果如下所示:
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

注:改变K的值后,需要在plt.scatter()中加上或删除一个类的显示,比如K=5,就需要多加一个显示散点图,K=3,则删除一个散点图的显示,不然会报错。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值