之前介绍的几种算法,都是监督学习算法,我们需要对数据进行预处理,也就是在使用数据前,需要对数据集的样本数据进行标记。今天我们看一种无监督学习算法——k-means。
k-means算法用来实现聚类,什么是聚类?打一个比方,我们在袋子中放着各种水果,我们事先并不知道有哪几种,通过一些算法,我们可以借助于特性将水果聚集为几个类别,然后我们再去看这几个类别分别代表了什么水果。
k-means算法的思想非常简单,假设有m条数据,n个特性:
随机选取k个点作为起始中心(k行n列的矩阵,每个特征都有自己的中心);
遍历数据集中的每一条数据,计算它与每个中心的距离;
将数据分配到距离最近的中心所在的簇;
使用每个簇中的数据的均值作为新的簇中心
如果簇的组成点发生变化,则跳转执行第2步;否则,结束聚类。
影响k-means的因素主要是k的选取,比如,数据可以分为三类,但是我们的k选择为2,那么就会有一个类被划分进了一个错误的类。所以,我们需要多尝试一些k值。另外,初始k个中心的选择,也会影响算法的执行。下面看看《机器学习实战》中的算法实现。
首先是选取初始随机中心的函数,需要注意的是我们需要对每个中心的n个特性分别计算中心值:
def rand_cent(data_set, k):
n = np.shape(data_set)[1]
centroids = np.mat(np.zeros((k, n)))
for j in range(n):
min_j = np.min(data_set[:, j])
range_j = float(np.max(data_set[:, j]) - min_j)
centroids[:, j] = min_j + range_j * np.random.rand(k, 1)
return centroids
接下来,我们采用欧式距离计算中心的距离:
def dist_eclud(vec_a, vec_b):
return np.sqrt(np.sum(np.power(vec_a - vec_b, 2)))
下面是k-means算法的核心:
def kmeans(data_set, k, dist_meas=dist_eclud, create_cent=rand_cent):
m = np.shape(data_set)[0]
cluster_assment = np.mat(np.zeros((m, 2)))
centroids = create_cent(data_set, k)
cluster_changed = True
while cluster_changed:
cluster_changed = False
for i in range(m):
min_dist = np.inf
min_index = -1
for j in range(k):
dist_ji = dist_meas(centroids[j, :], data_set[i, :])
if dist_ji < min_dist:
min_dist = dist_ji
min_index = j
if cluster_assment[i, 0] != min_index:
cluster_changed = True
cluster_assment[i, :] = min_index, min_dist ** 2
#print(centroids)
for cent in range(k):
pts_in_cluster = data_set[np.nonzero(cluster_assment[:, 0].A == cent)[0]]
centroids[cent, :] = np.mean(pts_in_cluster, axis=0)
return centroids, cluster_assment
centroids返回中心的信息,cluster_assment返回了簇的信息,m行2列,m行对应m条样本数据,第一列保存了该行数据所属簇的index,第二列保存了该行到中心的距离,也就是偏离中心的误差。while中的内容就是上面2-5步骤做的事情。
下面看看使用,我伪造了一组数据,这些数据实际上可以被分到4类,边界也比较清晰,主要目的是为了看看算法的作用:
if __name__ == '__main__':
data_set = np.mat([
[0.5, 0.3], [0.2, 0.7], [0.8, 0.9],
[9.5, 0.3], [9.2, 0.7], [9.8, 0.9],
[0.5, 9.3], [0.2, 9.7], [0.8, 9.9],
[9.5, 9.3], [9.2, 9.7], [9.8, 9.9],
])
centroids, cluster_assment = kmeans(data_set, 4)
import matplotlib.pyplot as plt
print(centroids.A[:, 0])
plt.scatter(centroids.A[:, 0], centroids.A[:, 1], marker='x')
plt.scatter(data_set.A[:, 0], data_set.A[:, 1])
plt.show()
运行结果如下:
运行结果和我们的预期相符合。