一、算法介绍
K-means算法是一种聚类算法,无监督算法
算法思想如下:
选择K个点做初始质心
repeat:
将每个点分配到距离最近的质心,形成K个簇
重新计算每个质心的位置
until
簇不发生变化或达到最大迭代次数
二、算法实现
距离度量采用欧氏距离
二维空间:
目标函数用误差平方和(Sum of the Squared Error,SSE)
K代表K个质心,Ci表示质心,dist表示欧氏距离
求得合适的Ci,对SSE求导,最终获得最佳质心是簇中各点的均值
三、Python实现
1. 读取数据
# 以矩阵的形式载入数据
from numpy import *
def loadDataSet(fileName):
dataSet = []
f = open(fileName)
for line in f.readlines():
curLine = line.strip().split(' ')
#print(curLine)
fltLine = list(map(float, curLine))
dataSet.append(fltLine)
return mat(dataSet)
2. 向量矩阵
def distEclud(vecA, vecB):
return sqrt(sum(power(vecA - vecB, 2)))
3. 选取初始质心
#生成k个质心
def randCent(dataSet, k):#dataSet是数据集,k是质心个数
n = shape(dataSet)[1] #n是列数
centroids = mat(zeros((k, n)))#生成零矩阵,k行n列
for j in range(n):
minJ = min(dataSet[:, j]) #找到第j列最小值,
rangeJ = float(max(dataSet[:, j]) - minJ) #求第j列最大值与最小值的差
centroids[:, j] = minJ + rangeJ * random.rand(k, 1) #生成介于最大值与最小值之间的质心列
return centroids
4. 算法实现
def KMeans(dataSet, k, distMeas=distEclud, createCent=randCent):
m = shape(dataSet)[0] #数据集的行
clusterAssment = mat(zeros((m, 2))) #记录每个数据所属的质心,及距离质心的距离
centroids = createCent(dataSet, k) #初始质心
clusterChanged = True
while clusterChanged:
clusterChanged = False
for i in range(m): #遍历数据集中的每一行数据
minDist = inf;minIndex = -1 #INF正无穷
for j in range(k): #遍历k个质心
distJI = distMeas(centroids[j, :], dataSet[i, :])
if distJI < minDist: #更新最小距离和质心下标
minDist = distJI; minIndex = j
if clusterAssment[i, 0] != minIndex:
clusterChanged = True
clusterAssment[i, :] = minIndex, minDist**2 #记录最小距离质心下标,最小距离的平方
print(centroids)
for cent in range(k): #更新质心位置
ptsInClust = dataSet[nonzero(clusterAssment[:,0].A==cent)[0]] #获得距离同一个质心最近的所有点的下标,即同一簇的坐标
centroids[cent,:] = mean(ptsInClust, axis=0) #求同一簇的坐标平均值,axis=0表示按列求均值
return centroids, clusterAssment
5. 做图函数
import matplotlib.pyplot as plt
def draw(data,center):
length=len(center)
fig=plt.figure
# 绘制原始数据的散点图
plt.scatter(data[:,0].tolist(),data[:,1].tolist(),s=25,alpha=0.4)
# 绘制簇的质心点
for i in range(length):
plt.annotate('center',xy=(center[i,0],center[i,1]),xytext=\
(center[i,0]+1,center[i,1]+1),arrowprops=dict(facecolor='red'))
plt.show()
6. 验证
dat=mat(loadDataSet('/Users/cy_ariel/Desktop/test.txt'))
center,clust=KMeans(dat,2,distEclud,randCent)
draw(dat,center)