# K-MEANS clustering demo

import random
import math
from matplotlib import pyplot as plt

# Generate data: two well-separated synthetic 2-D groups of integer points,
# 100 drawn from [-100, -50] x [-100, -50] followed by 100 drawn from
# [50, 100] x [50, 100]; the list is then shuffled so the groups interleave.
DATA = [
    [random.randint(low, high) for _ in range(2)]
    for low, high in ((-100, -50), (50, 100))
    for __ in range(100)
]
random.shuffle(DATA)

EPS = 1e-2  # convergence tolerance (maximum allowed coordinate change)
K = 2       # number of clusters to form

class K_MEANS:
    """Minimal 2-D k-means clusterer (Lloyd's algorithm).

    Construct with ``k``, supply points with :meth:`set_data`, then call
    :meth:`iterate` repeatedly; :meth:`get_parameter` returns the current
    centroids as a list of ``[x, y]`` lists.
    """

    def __init__(self, k: int = 1):
        self.k = k
        # Initial centroids: k random integer points in [0, 100] x [0, 100].
        self.barycenter = [[random.randint(0, 100) for _ in range(2)] for __ in range(k)]
        # Points to cluster, each an [x, y] pair; placeholder until set_data().
        self.DATA = [[0, 0]]

    def set_data(self, DATA: list) -> None:
        """Use *DATA* (a list of ``[x, y]`` points) as the points to cluster."""
        self.DATA = DATA

    def _get_dist(self, pid: int, kid: int) -> float:
        """Return the Euclidean distance from point *pid* to centroid *kid*."""
        x1, y1 = self.DATA[pid][0], self.DATA[pid][1]
        # Measure against the centroid, not against another data point.
        x2, y2 = self.barycenter[kid][0], self.barycenter[kid][1]
        return math.hypot(x1 - x2, y1 - y2)

    def _get_barycenter(self, points: list) -> list:
        """Return the mean ``[x, y]`` of the points whose indices are in *points*.

        Raises:
            ValueError: if *points* is empty (the mean is undefined).
        """
        if not points:
            raise ValueError("cannot compute the barycenter of an empty cluster")
        n = len(points)  # average over the cluster, not over the whole data set
        sx = sum(self.DATA[pid][0] for pid in points)
        sy = sum(self.DATA[pid][1] for pid in points)
        return [sx / n, sy / n]

    def iterate(self) -> None:
        """Run one Lloyd step.

        Each point is assigned to its nearest centroid, then each centroid
        moves to the mean of its assigned points. A centroid that receives no
        points keeps its previous position.
        """
        if self.k <= 0:
            return  # no centroids to assign points to
        point_class = [[] for _ in range(self.k)]
        for pid in range(len(self.DATA)):
            # min() keeps the first index on ties, like a strict '<' scan.
            kid = min(range(self.k), key=lambda i: self._get_dist(pid, i))
            point_class[kid].append(pid)
        for kid, members in enumerate(point_class):
            if members:
                self.barycenter[kid] = self._get_barycenter(members)

    def get_parameter(self):
        """Return a copy of the current centroids as a list of ``[x, y]`` lists.

        A copy is returned so callers can keep a snapshot and compare it with
        the centroids after the next :meth:`iterate` call. Returning the live
        list would alias it, because :meth:`iterate` updates it in place.
        """
        return [list(c) for c in self.barycenter]

def get_eps(pre: list, now: list) -> float:
    """Return the largest absolute coordinate change between two centroid sets.

    Args:
        pre: previous centroids, a list of coordinate lists.
        now: current centroids, with the same shape as *pre*.

    Returns:
        ``max |pre[i][j] - now[i][j]|`` over all ``i, j``; 0 when there are
        no coordinates to compare (e.g. both lists are empty).

    Raises:
        ValueError: if *pre* and *now* differ in the number of centroids or
            in the dimension of a centroid.
    """
    if len(pre) != len(now):
        raise ValueError("pre and now must contain the same number of centroids")
    max_eps = 0
    for old, new in zip(pre, now):
        if len(old) != len(new):
            raise ValueError("centroid dimensions differ between pre and now")
        for a, b in zip(old, new):
            max_eps = max(max_eps, abs(a - b))
    return max_eps

MAX_ITER = 100  # safety cap on iterations; convergence is normally reached well before

kmeans = K_MEANS(K)
kmeans.set_data(DATA)
for _ in range(MAX_ITER):
    # Snapshot a *copy* of the centroids before iterating. get_parameter()
    # may return the live list that iterate() updates in place, and comparing
    # the state after iterate() with itself would always give 0 and stop early.
    last_para = [list(c) for c in kmeans.get_parameter()]
    kmeans.iterate()
    if get_eps(last_para, kmeans.get_parameter()) <= EPS:
        break
print(kmeans.get_parameter())

# Plot every point as x against y (blue dots) and the final centroids (red x).
xs = [p[0] for p in DATA]
ys = [p[1] for p in DATA]
plt.plot(xs, ys, 'ob')
centers = kmeans.get_parameter()
plt.plot([c[0] for c in centers], [c[1] for c in centers], 'xr', markersize=12)
plt.show()
  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值