原理
对于最流行的聚类算法K-means算法,它的算法步骤如下:
1)从样本点中随机选择k个点作为初始簇中心。
2)将每个样本点划分到距离它最近的中心点
μ
(
j
)
\mu^{(j)}
μ(j),
j
∈
{
1
,
⋯
,
k
}
j\in\{1,\cdots,k\}
j∈{1,⋯,k}所代表的簇中。
3)用各簇中所有样本的中心点代替原有的中心点。
4)重复步骤2和3,直到中心点不变或达到预定迭代次数时,算法终止。
K-means算法的目标函数为簇内误差平方和(within-cluster sum of squared errors, SSE),也称为簇惯性(cluster inertia)
S
S
E
=
∑
i
=
1
n
∑
j
=
1
k
w
(
i
,
j
)
∣
∣
x
(
i
)
−
μ
j
∣
∣
2
SSE=\sum_{i=1}^{n}\sum_{j=1}^{k}w^{(i,j)}||\bm{x^{(i)}}-\bm{\mu^{j}}||^2
SSE=i=1∑nj=1∑kw(i,j)∣∣x(i)−μj∣∣2
若
x
i
\bm{x^{i}}
xi属于簇
j
j
j,则
w
(
i
,
j
)
=
1
w^{(i,j)}=1
w(i,j)=1,否则为0。
K-means算法存在初始点选择不恰当使最终结果为局部最优解或收敛过慢的问题,可以使用K-means++算法进行改进让初始中心点彼此尽量远离,使用该该方法仅需将KMeans
的init
参数从random
改为k-means++
。
代码与结果
代码引自《python机器学习》,如下所示:
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
X, y = make_blobs(n_samples=150, n_features=2, centers=3, cluster_std=0.5, shuffle=True, random_state=0)
plt.scatter(X[:,0], X[:,1], c='b', marker='o', s=50)# 将white改为其他的颜色
plt.grid()
plt.show()
km = KMeans(n_clusters=3, init='random', n_init=10, max_iter=300, tol=1e-04, random_state=0)
'''
设定簇数量为3,设置n_init=10,使程序能基于不同的随机初始中心点独立运行算法10次(跳过局部最优解),从中选择SSE最小的作为最终模型。
max_iter参数指定算法每轮运行的迭代次数。
'''
y_km = km.fit_predict(X)
plt.scatter(X[y_km == 0, 0], X[y_km == 0, 1], s=50, c='lightgreen', marker='s', label='cluster 1')
plt.scatter(X[y_km == 1, 0], X[y_km == 1, 1], s=50, c='orange', marker='o', label='cluster 2')
plt.scatter(X[y_km == 2, 0], X[y_km == 2, 1], s=50, c='lightblue', marker='v', label='cluster 3')
plt.scatter(km.cluster_centers_[:, 0],km.cluster_centers_[:, 1], s=250, c='red', marker='*', label='centroids')
plt.legend()
plt.grid()
plt.show()
初始数据集如下图所示:
聚类分析结果:
肘分析方法
通过肘分析方法,我们可以选定合适的K值,可以通过如下代码进行分析:
print('Distortion: %.2f' % km.inertia_)# 在完成KMeans模型的拟合后,簇内误差平方和可以通过inertia属性访问
distortions = []
for i in range(1,11) :
km = KMeans(n_clusters=i, init='k-means++', n_init=10, max_iter=300, random_state=0)
km.fit(X)
distortions.append(km.inertia_)
plt.plot(range(1,11), distortions, marker='o',)
plt.xlabel('Number of clusters')
plt.ylabel('Distortion')
plt.show()
结果如图:
拐点在
K
=
3
K=3
K=3处出现,与我们的初始设置是相符的。