Kmeans算法是最常用的聚类算法。
主要思想是:在给定K值和K个初始类簇中心点的情况下,把每个点(亦即数据记录)分到离其最近的类簇中心点所代表的类簇中,所有点分配完毕之后,根据一个类簇内的所有点重新计算该类簇的中心点(取平均值),然后再迭代的进行分配点和更新类簇中心点的步骤,直至类簇中心点的变化很小,或者达到指定的迭代次数。
其训练数据的流程是:
根据上面的流程图来实现具体代码:
数据集提取链接
链接:https://pan.baidu.com/s/1hqS9BOfwICAZ9IfTz5_BCA
提取码:zlyi
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd
class KMeans:
def __init__(self, k, times):
self.k = k
self.times = times
self.cluster_centers = None
self.labels = None
def fit(self, X):
X = np.asarray(X)
# 通过设置随机种子可以保证每次运行得到的随机序列是相同的,随机种子也可以是别的数,不一定为0
np.random.seed(0)
# 随机选择聚类中心
self.cluster_centers = X[np.random.randint(0, len(X), self.k)]
# 初始化标签为0
self.labels = np.zeros(len(X))
for i in range(self.times):
for index, x in enumerate(X):
# 计算每个点与聚类中心的距离
dis = np.sqrt(np.sum((x - self.cluster_centers) ** 2, axis=1))
# 将该样本点的标签设置为欧式距离最近的样本点
self.labels[index] = dis.argmin()
# 根据已经初步分类的结果计算均值更新簇中心,self.k为簇中心的个数
for j in range(self.k):
self.cluster_centers[j] = np.mean(X[self.labels == j], axis=0)
def predict(self, X):
X = np.asarray(X)
result = np.zeros(len(X))
for index, x in enumerate(X):
dis = np.sqrt(np.sum((x - self.cluster_centers) ** 2, axis=1))
result[index] = dis.argmin()
return result
if __name__ == '__main__':
data = pd.read_csv('order.csv')
t = data.iloc[:, -8:]
kmeans = KMeans(3, 50)
# 可视化代码,因为多个特征不好实现可视化,遂选取2个特征来进行可视化的实现
t2 = t.loc[:, "Food%":"Fresh%"]
kmeans.fit(t2)
# 设置可视化中文的实现,字体
mpl.rcParams['font.family'] = 'SimHei'
mpl.rcParams['axes.unicode_minus'] = False
plt.figure(figsize=(10, 10))
# result = kmeans.predict([[30,30,40,0,0,0,0,0],[0,0,0,0,0,30,30,40],[30,30,0,0,0,0,20,20]])
# print(result) 预测的代码
# 画出各个簇的可视化图
plt.scatter(t2[kmeans.labels == 0].iloc[:, 0], t2[kmeans.labels == 0].iloc[:, 1], label='类别1')
plt.scatter(t2[kmeans.labels == 1].iloc[:, 0], t2[kmeans.labels == 1].iloc[:, 1], label='类别2')
plt.scatter(t2[kmeans.labels == 2].iloc[:, 0], t2[kmeans.labels == 2].iloc[:, 1], label='类别3')
# 簇中心用 + 画出
plt.scatter(kmeans.cluster_centers[:, 0], kmeans.cluster_centers[:, 1], marker='+', s=300)
plt.title('分析')
plt.xlabel('食物')
plt.ylabel('肉类')
plt.legend()
plt.show()