K-MEANS概念及复现

K-MEANS概念及复现

1.相关概念

聚类概念:

  • 无监督问题:我们手里没有标签了
  • 聚类:相似的东西分到一组
  • 难点:如何评估,如何调参

如下图,可看出有三个类别分别对应三个不同的颜色,通俗的来说:此算法就是类似于一个投票,少数服从多数,大概就是假如一个男生在一群女生里面,那么他也会被分类成一个女生

在这里插入图片描述

基本概念:

  • 要得到簇的个数,需要指定K值(即K的取值对分类的结果存在影响)
  • 质心:均值,即向量各维度取平均值即可
  • 距离的度量:常用欧几里得距离和余弦相似度(注意需要标准化)
  • 优化目标

2.K-MEANS算法的优缺点

优势:

  • 简单,快速,适合常规数据集,因此,此算法比较适合处理那种已经分类好的数据集。

劣势:

  • K值难确定
  • 复杂度与样本呈线性关系
  • 很难发现任意形状的族
可视化展示,大家可以自己去搜索这个网站,去了解,非常简单清晰
https://www.naftaliharris.com/blog/visualizing-k-means-clustering/

3.代码实现

  • k_means.py
import numpy as np

class KMeans:
    def __init__(self, data, num_clustres):
        self.data = data  # 初始化数据
        self.num_clustres = num_clustres  # 初始化簇的数量

    def train(self, max_iterations):
        # 1.先随机选择K个中心点
        centroids = KMeans.centroids_init(self.data, self.num_clustres)  # 调用静态方法初始化中心点
        # 2.开始训练
        num_examples = self.data.shape[0]  # 获取样本数量
        closest_centroids_ids = np.empty((num_examples, 1))  # 初始化最近中心点的索引数组
        for _ in range(max_iterations):  # 迭代更新中心点
            # 3得到当前每一个样本点到K个中心点的距离,找到最近的
            closest_centroids_ids = KMeans.centroids_find_closest(self.data, centroids)  # 调用静态方法计算最近中心点的索引
            # 4.进行中心点位置更新
            centroids = KMeans.centroids_compute(self.data, closest_centroids_ids, self.num_clustres)  # 调用静态方法计算新的中心点
        return centroids, closest_centroids_ids  # 返回最终的中心点和最近中心点的索引

    @staticmethod
    def centroids_init(data, num_clustres):
        num_examples = data.shape[0]  # 获取样本数量
        random_ids = np.random.permutation(num_examples)  # 随机打乱样本顺序
        centroids = data[random_ids[:num_clustres], :]  # 选取前K个样本作为初始中心点
        return centroids  # 返回初始中心点

    @staticmethod
    def centroids_find_closest(data, centroids):
        num_examples = data.shape[0]  # 获取样本数量
        num_centroids = centroids.shape[0]  # 获取中心点数量
        closest_centroids_ids = np.zeros((num_examples, 1))  # 初始化最近中心点的索引数组
        for example_index in range(num_examples):  # 遍历每个样本
            distance = np.zeros((num_centroids, 1))  # 初始化距离数组
            for centroid_index in range(num_centroids):  # 遍历每个中心点
                distance_diff = data[example_index, :] - centroids[centroid_index, :]  # 计算样本与中心点之间的距离差
                distance[centroid_index] = np.sum(distance_diff ** 2)  # 计算距离平方和
            closest_centroids_ids[example_index] = np.argmin(distance)  # 找到距离最小的中心点的索引
        return closest_centroids_ids  # 返回最近中心点的索引

    @staticmethod
    def centroids_compute(data, closest_centroids_ids, num_clustres):
        num_features = data.shape[1]  # 获取特征数量
        centroids = np.zeros((num_clustres, num_features))  # 初始化中心点数组
        for centroid_id in range(num_clustres):  # 遍历每个中心点
            closest_ids = closest_centroids_ids == centroid_id  # 找到属于当前中心点的样本
            centroids[centroid_id] = np.mean(data[closest_ids.flatten(), :], axis=0)  # 计算当前中心点的坐标
        return centroids  # 返回中心点坐标
  • demo.py
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib
from k_means import KMeans

matplotlib.use('TkAgg')
data = pd.read_csv('./data/iris.csv')
iris_types = ['SETOSA', 'VERSICOLOR', 'VIRGINICA']

x_axis = 'petal_length'
y_axis = 'petal_width'

plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
for iris_type in iris_types:
    plt.scatter(data[x_axis][data['class'] == iris_type], data[y_axis][data['class'] == iris_type], label=iris_type)
plt.title('label known')
plt.legend()

plt.subplot(1, 2, 2)
plt.scatter(data[x_axis][:], data[y_axis][:])
plt.title('label unknown')
plt.show()

num_examples = data.shape[0]
x_train = data[[x_axis, y_axis]].values.reshape(num_examples, 2)

# 指定好训练所需的参数
num_clusters = 3
max_iteritions = 50

k_means = KMeans(x_train, num_clusters)
centroids, closest_centroids_ids = k_means.train(max_iteritions)
# 对比结果
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
for iris_type in iris_types:
    plt.scatter(data[x_axis][data['class'] == iris_type], data[y_axis][data['class'] == iris_type], label=iris_type)
plt.title('label known')
plt.legend()

plt.subplot(1, 2, 2)
for centroid_id, centroid in enumerate(centroids):
    current_examples_index = (closest_centroids_ids == centroid_id).flatten()
    plt.scatter(data[x_axis][current_examples_index], data[y_axis][current_examples_index], label=centroid_id)

for centroid_id, centroid in enumerate(centroids):
    plt.scatter(centroid[0], centroid[1], c='black', marker='x')
plt.legend()
plt.title('label kmeans')
plt.show()

在demo.py软件运行时候,出现以下报错:

AttributeError: module 'backend_interagg' has no attribute 'FigureCanvas'. Did you mean: 'FigureCanvasAgg'

通过搜索,发现是matplotlib的backend的默认渲染器是agg,agg是一个没有图形显示界面的终端,如果要图像正常显示,则需要切换为图形界面显示的终端TkAgg

import matplotlib
matplotlib.use('TkAgg')

结果展示:

结果一:
在这里插入图片描述

结果二:(×代表质心)
在这里插入图片描述

  • 14
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

大杯奶绿加少冰

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值