**
K-Means算法的python实现,并且给出数据集。
**
k-means算法简介
它是一种基础且常用的聚类算法,也就是对数据进行分类,这里以坐标数据分类为例子,具体原理可以百度。
话不多说,上代码。
from copy import deepcopy
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
plt.rcParams['figure.figsize'] = (16, 9)
plt.style.use('ggplot')
# 导入数据集
#这里根据自己数据文件的位置,注意要用///,而不是\\\,切记
data = pd.read_csv('E:/python/xclara.csv')
#print(data['V1'])
# 将csv文件中的数据转换为二维数组
f1 = data['V1'].values
f2 = data['V2'].values
X = np.array(list(zip(f1, f2)))
plt.scatter(f1, f2, c='black', s=6)
#plt.show()
# 按行的方式计算两个坐标点之间的距离
#其中的axis=0表示对矩阵的每一列求范数,axis=1表示对矩阵的每一行求范数,
#keeptime=True表示结果保留二维特性,keeptime=False表示结果不保留二维特性
def dist(a, b):
ax=1
return np.linalg.norm(a - b, axis=ax)
# 设定分区数
k = 3
# 随机获得中心点的X轴坐标
C_x = np.random.randint(0, int(np.max(X)-20), size=k)
# 随机获得中心点的Y轴坐标
C_y = np.random.randint(0, int(np.max(X)-20), size=k)
C = np.array(list(zip(C_x, C_y)), dtype=np.float32)
# 将初始化中心点画到输入的样例数据上
plt.scatter(f1, f2, c='black', s=7)
plt.scatter(C_x, C_y, marker='*', s=200, c='red')
#plt.show()
# 用于保存中心点更新前的坐标
C_old = np.zeros(C.shape)
print(C)
# 用于保存数据所属中心点
clusters = np.zeros(len(X))
# 迭代标识位,通过计算新旧中心点的距离
iteration_flag = dist(C, C_old)
tmp = 1
# 若中心点不再变化或循环次数不超过20次(此限制可取消),则退出循环
while iteration_flag.any() != 0 and tmp < 20:
# 循环计算出每个点对应的最近中心点
for i in range(len(X)):
# 计算出每个点与中心点的距离
distances = dist(X[i], C)
# print(distances)
# 记录0 - k-1个点中距离近的点
cluster = np.argmin(distances)
# 记录每个样例点与哪个中心点距离最近
clusters[i] = cluster
# 采用深拷贝将当前的中心点保存下来
# print("the distinct of clusters: ", set(clusters))
C_old = deepcopy(C)
# 从属于中心点放到一个数组中,然后按照列的方向取平均值
for i in range(k):
points = [X[j] for j in range(len(X)) if clusters[j] == i]
# print(points)
# print(np.mean(points, axis=0))
C[i] = np.mean(points, axis=0)
# print(C[i])
# print(C)
# 计算新旧节点的距离
print ('循环第%d次' % tmp)
tmp = tmp + 1
iteration_flag = dist(C, C_old)
print("新中心点与旧点的距离:", iteration_flag)
# 最终结果图示
colors = ['r', 'g', 'b', 'y', 'c', 'm']
fig, ax = plt.subplots()
# 不同的子集使用不同的颜色
for i in range(k):
points = np.array([X[j] for j in range(len(X)) if clusters[j] == i])
ax.scatter(points[:, 0], points[:, 1], s=7, c=colors[i])
ax.scatter(C[:, 0], C[:, 1], marker='*', s=200, c='black')
plt.show()
附上xclara.csv数据集
链接:https://pan.baidu.com/s/1imDldymuAkAWd53uZn_NoQ
提取码:s7wj
–来自百度网盘超级会员V3的分享