Python实现Kmeans算法

# coding=UTF-8
import random
import matplotlib.pyplot as plt
import numpy as np
from numpy import mean, inf
from sklearn import datasets


#  Kemans算法实现
class Kmeans(object):
    """
    参数:
        k: 分类的类数
        n_iter: 最大拟合迭代次数
    属性:
        c_point:中心点位置
        assement:预测的样本分类
    """

    def __init__(self, k=3, n_iter=50):
        self.k = k
        self.n_iter = n_iter
        self.c_point = None
        self.assement = None

    # 拟合
    def fit(self, x):
        # 将数据集向量化
        x = np.array(x)
        # 将预测结果初始化
        self.assement = np.zeros(x.shape[0])
        # 随机初始化中心点
        self.c_point = self.SelectCenterPoint(x)
        
        # 存储最短距离
        min_dis = []
        for i in range(x.shape[0]):
            min_dis.append(inf)

        # 存储上一次迭代后属于每种中心点的点集合,共K个
        old_points = x

        count = 0
        while count < self.n_iter:
            # 计算每个点最近的中心点,并将其类别更新为中心点类别
            for i in range(x.shape[0]):
                for j in range(self.k):
                    now_dis = self.euc_dis(x[i], self.c_point[j])
                    if now_dis < min_dis[i]:
                        min_dis[i] = now_dis
                        self.assement[i] = j

            # 重新计算每个集合的中心点位置
            new_points = []
            for i in range(self.k):
                new_points.append([])
                for j in range(self.assement.shape[0]):
                    if self.assement[j] == i:
                        temp = x[j].tolist()
                        new_points[i].append(temp)
                self.c_point[i, :] = mean(new_points[i], axis=0)

            # 有时会出现某个中心点偏离数据集过远,造成类集合中无数据元素的现象,此时应重新初始化中心点
            if len(set(self.assement)) != self.k:
                self.c_point = self.SelectCenterPoint(x)
                count = 0
                continue

            # 当集合中元素不再更新时则可以提前结束迭代
            if self.setequal(old_points, new_points, self.k):
                break
            else:
                old_points = new_points
            count += 1
        return self

    # 随机初始化中心点
    def SelectCenterPoint(self, x):
        dimension = x.shape[1]
        points = np.zeros((self.k, dimension))
        for i in range(self.k):
            for j in range(dimension):
                x_min = np.min(x[:, j])
                x_max = np.max(x[:, j])
                points[i, j] = random.uniform(x_min, x_max)
        return points

    # 计算欧式距离
    def euc_dis(self, a, b):
        return np.sqrt(np.sum((a - b) ** 2))

    # 判断新旧集合中元素是否发生改变
    def setequal(self, a, b, k):
        if len(a) != k or len(b) != k:
            return False

        for i in range(k):
            if a[i] != b[i]:
                return False

        return True

    # 绘制最终结果
    def drawresult(self, x, y):
        x0 = []
        x1 = []
        x2 = []
        for i in range(y.shape[0]):
            if y[i] == 0:
                x0.append(x[i, :])
            elif y[i] == 1:
                x1.append(x[i, :])
            elif y[i] == 2:
                x2.append(x[i, :])
        x0 = np.array(x0)
        x1 = np.array(x1)
        x2 = np.array(x2)
        plt.scatter(self.c_point[:, 0], self.c_point[:, 1], c='black', alpha=0.9)
        plt.scatter(x0[:, 0], x0[:, 1], c='blue', alpha=0.5)
        plt.scatter(x1[:, 0], x1[:, 1], c='green', alpha=0.5)
        plt.scatter(x2[:, 0], x2[:, 1], c='red', alpha=0.5)
        plt.show()


if __name__ == '__main__':
    # 使用鸢尾花数据集作为测试数据集
    iris = datasets.load_iris()
    x = iris.data
    # 取其中两个维度数据便于画图
    x = x[:, [1, 2]]
    k = Kmeans()
    k.fit(x)
    k.drawresult(x, k.assement)

  • 0
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

羽路星尘

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值