聚类
聚类是一种无监督的学习,它将相似的对象归到同一个簇中。聚类的方法几乎可以应用于所有对象,簇内的对象越相似,聚类的效果越好。K-means(K-均值聚类)算法使一种聚类算法。之所以称之为K-均值使因为它可以发现k个不同的簇,且每个簇的中心采用簇中所含值的均值计算而成。
K-均值聚类算法
K-均值是发现给定数据集的k个簇的算法。簇个数k是用户给定的,每一个簇通过其聚类中心,即簇中所有点的中心来描述。
K-均值算法的工作流程为——首先,随机确定k个初始点作为质心。然后将数据集中的每个点分配到一个簇中,具体来讲,为每个点找距其最近的聚类中心,将其分配给该聚类中心所对应的簇。这一步完成之后,每个簇的聚类中心更新为该簇所有点的平均值。
上述过程的伪代码表示如下:
创建k个点作为初始聚类中心(一般从样本中随机选择,我的代码里是硬编码为一个数组)
当任意一个点的簇分配结果发生改变时
对数据集中的每个数据点
对每个聚类中心
计算聚类中心与数据点之间的距离
将数据点分配到距其最近的簇
对每一个簇,计算簇中所有点的均值并将均值作为聚类中心
K-means的python实现:
# -*- coding: utf-8 -*-
from scipy.io import loadmat
from numpy import *
import matplotlib.pyplot as plt
def find_closest_centroids(centroids,data):
'''
找出每组数据最接近的聚类中心
Args:
centroids:聚类中心
data:数据集
Returns:
idx:数据集所属聚类中心的下标
'''
data_size = len(data)
idx = zeros((data_size,1))
K = len(centroids)
for i in range(data_size):
temp = sum(power(data[i] - centroids,2),1) # 求每组数据距离的平方,求距离省略了开方过程,大小关系不受影响
idx[i] = temp.argmin() # 数据集所属聚类中心的下标
return idx
def compute_centroids(data,idx,k):
'''
重新计算聚类中心
Args:
data:数据集
idx:数据集所属聚类中心的下标
k:聚类数量
Returns:
centroids:更新后的聚类中心
'''
data_size = len(data)
cluster_sum = zeros((k,2)) # 一个簇中所有数据之和,用于求平均值
cluster_data_size = zeros((k,1)) # 一个簇包含的数据长度,用于求平均值
for i in range(data_size): # 遍历每一组数据,计算cluster_sum和cluster_data_size
cluster_sum[int(idx[i])] = cluster_sum[int(idx[i])] + data[i]
cluster_data_size[int(idx[i])] = cluster_data_size[int(idx[i])] + 1
centroids = cluster_sum / cluster_data_size
return centroids
def figure_cluster(data,centroids,idx):
'''
在坐标轴中根据聚类画出数据样本,描出聚类中心
Args:
data:数据集
centroids:聚类中心
idx:数据集所属聚类中心的下标
Returns:
centroids:更新后的聚类中心
'''
data_size = len(data)
colors = ['r','b','y']
for i in range(data_size):
index = int(idx[i])
plt.scatter(data[i,0],data[i,1],marker='o',c=colors[index],s=15)
print(centroids[:,0],centroids[:,1])
plt.scatter(centroids[:,0],centroids[:,1],marker='x',color='000')
data_mat = loadmat("ex7data2.mat") # mat数据集
data = data_mat['X']
K = 3
initial_centroids = array([[3,3],[6,2],[8,5]]) # 硬编码初始聚类中心
centroids = initial_centroids
'''
迭代,当更新的聚类中心不再变化时结束
'''
while True:
old_centroids = centroids
idx = find_closest_centroids(centroids,data)
centroids = compute_centroids(data,idx,K)
if (old_centroids == centroids).all():
break
figure_cluster(data,centroids,idx)
PS:此处所用的数据集是吴恩达机器学习课程中的数据集,是mat类型的数据,可自行更改为自己的数据集