一、算法原理
聚类(Clustering)是一种典型的非监督机器学习任务,用于将无标签的输入数据按照一定的特征来区分为不同的类别。与分类(Classification)相比,其不会生成有意义的类别标签。
比如根据形状,对由若干片点云组成的数据集(“猫”、“狗”、“人”)进行聚类,不考虑错误,也只能得到三个类别(1,2,3),这三个类别只表示这个数据集在形状这一标准下可以被分成这三种不同类别,类别“1”有可能是“猫”、“狗”、“人”中的任何一种。因此在完成聚类后还需通过其他方法来验证分出的类别的具体含义,以及聚类操作是否有效。
聚类的方法有很多,可以基于划分(K-means、K-plane、K-Medoids等)、基于密度(DBSCAN、HDBSCAN等)、基于层次(BIRCH、CURE等)等多种方法进行。而其中,基于欧式距离划分的K-means聚类算法又是最基本的一种。
K-means算法的聚类思想是:处于同一类内的点,到该类聚类中心的距离应当尽可能小,而处于不同类内的点,到其他聚类中心的距离应当尽可能大。
二、执行过程
K-means算法的执行过程为:
首先,指定数据集划分的类数K,随机在样本空间中选取K个点作为初始聚类中心;
然后,计算每个点,到K个聚类中心的距离,将该点分配给离它最近的聚类中心,遍历所有数据点后,形成了初始的聚类点云;
然后,对每一类的聚类点云取质心,作为新的聚类中心;
然后,再次计算每个点到新的K个聚类中心的距离,同样将该点分配给离它最近的聚类中心。遍历后,形成新的聚类点云;
对新的聚类点云取质心,再形成下一轮迭代的聚类中心,一直迭代下去。可以指定迭代次数,也可以指定质心不再变化或变化很小时,终止执行,输出聚类结果。
三、代码实现
整体代码分为四个模块,分别用于初始化聚类中心、计算所有点到聚类中心的距离、更新聚类中心和返回聚类结果。
import numpy as np
from scipy.spatial import KDTree
import pandas as pd
import matplotlib.pyplot as plt
import time
class K_means:
def __init__(self, data, k, max_iter):
self.data = data
self.k = k
self.max_iter = max_iter
def fit(self):
centroids =K_means.centroids_init(self.data, self.k)
for _ in range(self.max_iter):
cluster_labels = K_means.get_clusters(self.data, centroids)
new_centroids = K_means.update_centroids(self.data, cluster_labels, k)
centroids = new_centroids
return cluster_labels, centroids
def centroids_init(data, k):
centers = data[np.random.choice(data.shape[0], k, replace=False)]
return centers
def get_clusters(data, centroids):
distances = np.linalg.norm(data[:, np.newaxis] - centroids, axis=2)
cluster_labels = np.argmin(distances, axis=1)
return cluster_labels
def update_centroids(data, cluster_labels, k):
new_centroids = np.array([data[cluster_labels == i].mean(axis=0) for i in range(k)])
return new_centroids
测试代码,使用sklearn库内置的尾花数据集,也可以自定义数据集:
from sklearn.datasets import load_iris
iris_dataset = load_iris()
data=iris_dataset['data']
fig = plt.figure()
ax = fig.add_subplot(121,projection='3d')
ax1 = fig.add_subplot(122,projection='3d')
k = 3 # 类别数
T = 100 # 最大迭代数
test=K_means(data, k,T)
labels, centers = test.fit()
print(labels)
ax.scatter(data[:,0],data[:,1],data[:,2],c=data[:,3],s=3)
ax1.scatter(data[:,0],data[:,1],data[:,2],c=labels,s=3)
plt.show()