K-means Algorithm

Introduction

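K-means is an unsupervised learning algorithm. It starts by choosing k initial cluster centers, typically at random (the code below simply takes the first k samples). Each training sample then computes its distance to all k centers and is assigned to the nearest one. Next, the samples assigned to each center are averaged to give that cluster's new center. Finally, the new centers are compared with the old ones: if the difference is smaller than a chosen minimum distance, the centers have been found; otherwise the assignment and update steps repeat.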
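The steps above can be sketched in vectorized NumPy form. This is only a minimal sketch under stated assumptions, not the class implemented in this post: the helper name `kmeans_step`, the tolerance `tol`, and the four 2D sample points are hypothetical and chosen purely for illustration.

```python
import numpy as np

def kmeans_step(data, centers):
    """One assignment + update step of k-means (hypothetical helper)."""
    # Pairwise Euclidean distances between samples and centers, shape (n_samples, k).
    dists = np.linalg.norm(data[:, None, :] - centers[None, :, :], axis=2)
    # Each sample is assigned to its nearest center.
    labels = dists.argmin(axis=1)
    new_centers = centers.copy()
    for j in range(len(centers)):
        members = data[labels == j]
        if len(members) > 0:  # an empty cluster keeps its old center
            new_centers[j] = members.mean(axis=0)
    # Largest distance any center moved in this step.
    shift = np.linalg.norm(new_centers - centers, axis=1).max()
    return new_centers, labels, shift

# Illustrative data: two obvious groups in 2D (assumed, not from the post).
data = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])
centers = data[[0, 2]].copy()  # assumed initial centers
tol = 1e-4                     # assumed convergence threshold
for _ in range(300):
    centers, labels, shift = kmeans_step(data, centers)
    if shift < tol:
        break
```

Stopping on the largest center shift is one possible convergence criterion; the class below uses the mean shift over all centers instead.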

Code

import numpy as np
from copy import deepcopy

class Kmeans(object):
    """
    Kmeans Algorithm
    """
    def __init__(self, 
                 k: int=3, 
                 aev: float=0.0001, 
                 max_iter: int=300):
        """
        Args:
            k: number of cluster centers.
            aev: accepted error value (convergence threshold).
            max_iter: maximum number of iterations.
        """
        self.k = k
        self.aev = aev
        self.max_iter = max_iter
        self.centers = {}

    def fit(self, data):
        """
        First: initialize the cluster centers.
        Second: assign each sample to its nearest center.
        Third: recompute each center as the mean of its samples.
        """
        # Use the first k samples as the initial centers.
        for i in range(self.k):
            self.centers[i] = data[i]

        for epoch in range(self.max_iter):
            # Group the samples by the index of their nearest center.
            category_data = {i: [] for i in range(self.k)}
            for sample in data:
                distances = []
                for center_value in self.centers.values():
                    # Euclidean distance from the sample to this center.
                    distance = np.sqrt(np.sum((sample - center_value)**2))
                    distances.append(distance)
                category = distances.index(min(distances))
                category_data[category].append(sample)

            # Update each center to the mean of its assigned samples.
            pre_centers = deepcopy(self.centers)
            for ik, category_sample in category_data.items():
                if category_sample:  # an empty cluster keeps its old center
                    self.centers[ik] = np.average(category_sample, axis=0)

            # Measure how far each center moved in this iteration.
            all_d = []
            for jk in range(self.k):
                pre_center = pre_centers[jk]
                update_center = self.centers[jk]
                d = np.sqrt(np.sum((update_center - pre_center)**2))
                all_d.append(d)

            # Stop when the mean shift is below aev, or on the last iteration.
            if sum(all_d)/self.k < self.aev or epoch == self.max_iter-1:
                print("last loss:", sum(all_d)/self.k)
                print("end epoch:", epoch)
                break

    def get_centers(self):
        """
        Return the cluster centers as a list of lists.
        """
        res = []
        for v in self.centers.values():
            res.append(list(v))
        print("final cluster centers:", res)
        return res

    def predict(self, pre_data):
        """
        Return the index of the cluster center nearest to pre_data.
        """
        distances = []
        for center_value in self.centers.values():
            # Squared L2 distance; the nearest center is the same as with the square root.
            distance = (pre_data - center_value)**2
            distances.append(distance.sum())
        category = distances.index(min(distances))
        return category

if __name__ == "__main__":
    from matplotlib import pyplot as plt
    fig4 = plt.figure()
    ax4 = plt.axes(projection='3d')

    x = np.array([[1, 2, 3], 
                  [1.5, 1.8, 4], 
                  [5, 8, 5], 
                  [8, 8, 3], 
                  [1, 0.6, 7], 
                  [10, 6.6, 7], 
                  [13, 1.6, 4.7], 
                  [9, 11, 13]])

    kmeans = Kmeans(k=3)
    kmeans.fit(x)
    kmeans.get_centers()
    test_data = np.random.randint(0, 20, (1000, 3))
    # One scalar color value per cluster; matplotlib maps them through its colormap.
    colors = (0.3, 0.7, 1)
    categories = []
    for per_data in test_data:
        pre_res = kmeans.predict(per_data)
        categories.append(colors[pre_res])
    X, Y, Z = test_data[:, 0], test_data[:, 1], test_data[:, 2]
    ax4.scatter(X, Y, Z, alpha=0.3, c=categories)
    plt.show()
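The per-sample prediction loop in the demo can also be written as a single batched computation. The following is a minimal sketch, assuming the centers are available as a `(k, n_features)` array; `predict_batch` is a hypothetical helper and not a method of the class above, and the centers and points are made-up values for illustration.

```python
import numpy as np

def predict_batch(points, centers):
    """Return the nearest-center index for every row of points (hypothetical helper)."""
    # Broadcast to shape (n_points, k, n_features), then sum squared differences.
    diff = points[:, None, :] - centers[None, :, :]
    sq_dists = (diff ** 2).sum(axis=2)
    # Squared distances give the same nearest center as Euclidean distances.
    return sq_dists.argmin(axis=1)

# Illustrative values only (assumed, not produced by the demo above).
centers = np.array([[0.0, 0.0, 0.0], [10.0, 10.0, 10.0]])
points = np.array([[1, 1, 1], [9, 8, 9], [4, 4, 4]])
labels = predict_batch(points, centers)
```

With a fitted model, the centers array could be built from the list returned by `get_centers()`, for example with `np.array(kmeans.get_centers())`.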

Results

For how to draw 3D plots, see this reference (very well written): 3D plotting methods
(Figure: kmeans result plot)
