题目
试编程实现k均值算法，设置三组不同的k值、三组不同初始中心点，在西瓜数据集4.0上进行实验比较，并讨论什么样的初始中心有利于取得好的结果。
代码
import numpy as np
import matplotlib.pyplot as plt
def createDataSet():
    """
    Build the test dataset (watermelon dataset 4.0); all feature values are continuous.

    Every row is ``[density, sugar_content, tag]``. The tag is -1 for rows
    9..21 (0-based, inclusive) and 1 for every other row; it is a placeholder
    column, not a feature.

    :return: ``(dataSet, labels)`` where ``dataSet`` is a (30, 3) float
        ndarray and ``labels`` lists the names of the two feature columns.
    """
    samples = (
        (0.697, 0.460), (0.774, 0.376), (0.634, 0.264), (0.608, 0.318),
        (0.556, 0.215), (0.403, 0.237), (0.481, 0.149), (0.437, 0.211),
        (0.666, 0.091), (0.243, 0.267), (0.245, 0.057), (0.343, 0.099),
        (0.639, 0.161), (0.657, 0.198), (0.360, 0.370), (0.593, 0.042),
        (0.719, 0.103), (0.359, 0.188), (0.339, 0.241), (0.282, 0.257),
        (0.748, 0.232), (0.714, 0.346), (0.483, 0.312), (0.478, 0.437),
        (0.525, 0.369), (0.751, 0.489), (0.532, 0.472), (0.473, 0.376),
        (0.725, 0.445), (0.446, 0.459),
    )
    # Feature names: density, sugar content.
    featureNames = ['密度', '含糖率']
    rows = [
        [density, sugar, -1 if 9 <= rowIdx <= 21 else 1]
        for rowIdx, (density, sugar) in enumerate(samples)
    ]
    return np.array(rows), featureNames
# dataSet, labels = createDataSet()
# print(dataSet)
# print(labels)
def kMeans(dataSet, k, initCenters=None, maxIter=300):
    """
    K-means clustering (Lloyd's algorithm: alternate assign / update steps).

    The last column of ``dataSet`` is treated as a tag and ignored; all the
    other columns are the features being clustered.

    :param dataSet: 2-D array-like, one sample per row; last column ignored.
    :param k: number of clusters (centers).
    :param initCenters: optional ``(k, n_features)`` array-like of initial
        centers. When None, k *distinct* samples are drawn at random (without
        replacement, using numpy's global RNG, so ``np.random.seed`` makes runs
        reproducible). Passing explicit centers lets different
        initialisations be compared on the same data.
    :param maxIter: maximum number of assign/update rounds; a safeguard in
        case the centers never become exactly stationary in floating point.
    :return: ``(cluster, mu)``. ``cluster`` maps a center index to the list of
        sample row indices assigned to it; centers with no samples are absent.
        ``mu`` is a ``(k, n_features)`` float ndarray of the final centers.
    :raises ValueError: if ``k`` is outside ``[1, n_samples]`` for random
        initialisation, if ``initCenters`` does not have shape
        ``(k, n_features)``, or if ``maxIter < 1``.
    """
    dataArr = np.asarray(dataSet, dtype=float)[:, :-1]
    numSamples, numFeatures = dataArr.shape
    if maxIter < 1:
        raise ValueError("maxIter must be >= 1, got %r" % (maxIter,))

    if initCenters is None:
        if not 1 <= k <= numSamples:
            raise ValueError(
                "k must be between 1 and the number of samples (%d), got %r"
                % (numSamples, k)
            )
        # Sample without replacement: duplicate initial centers would leave
        # one of them with no members right from the first round.
        mu = dataArr[np.random.choice(numSamples, k, replace=False)]
    else:
        mu = np.array(initCenters, dtype=float)  # copy; caller's data untouched
        if mu.shape != (k, numFeatures):
            raise ValueError(
                "initCenters must have shape (%d, %d), got %r"
                % (k, numFeatures, mu.shape)
            )

    assign = np.empty(0, dtype=int)
    for _ in range(maxIter):
        # Squared Euclidean distance from every sample to every center,
        # shape (numSamples, k). Squaring preserves the ordering of the true
        # distances; argmin returns the lowest index on ties.
        sqDist = ((dataArr[:, np.newaxis, :] - mu[np.newaxis, :, :]) ** 2).sum(axis=2)
        assign = sqDist.argmin(axis=1)

        newMu = mu.copy()
        for j in range(k):
            members = dataArr[assign == j]
            if len(members) > 0:
                newMu[j] = members.mean(axis=0)
            # An empty cluster keeps its previous center instead of crashing.

        converged = np.array_equal(newMu, mu)
        mu = newMu
        if converged:
            break

    # Group sample indices by assigned center, keys in order of first use.
    cluster = {}
    for sampleIdx, centerIdx in enumerate(assign.tolist()):
        cluster.setdefault(centerIdx, []).append(sampleIdx)
    return cluster, mu
def main():
    """
    Cluster watermelon dataset 4.0 with k = 3 and plot the result.

    Prints the cluster dict and the final centers, then shows a scatter plot
    with one series per cluster (labelled by its center index) and the
    centers drawn as red "+" markers.
    """
    dataSet, _ = createDataSet()
    cluster, mu = kMeans(dataSet, 3)
    print(cluster)
    print(mu)
    for key, members in cluster.items():
        data = dataSet[members]  # fancy indexing already returns a new array
        plt.scatter(data[:, 0], data[:, 1], label=key)
    # Draw the centers once, after all clusters, so the markers are not
    # stacked k times and sit on top of every cluster's points.
    plt.scatter(mu[:, 0], mu[:, 1], s=80, c='r', marker="+")
    plt.legend()
    plt.show()


if __name__ == '__main__':
    main()