使用K-means算法进行聚类分析

原理

对于最流行的聚类算法K-means算法,它的算法步骤如下:
1)从样本点中随机选择k个点作为初始簇中心。
2)将每个样本点划分到距离它最近的中心点 μ ( j ) \mu^{(j)} μ(j) j ∈ { 1 , ⋯   , k } j\in\{1,\cdots,k\} j{1,,k}所代表的簇中。
3)用各簇中所有样本的中心点代替原有的中心点。
4)重复步骤2和3,直到中心点不变或达到预定迭代次数时,算法终止。
K-means算法的目标函数为簇内误差平方和(within-cluster sum of squared errors, SSE),也称为簇惯性(cluster inertia)
S S E = ∑ i = 1 n ∑ j = 1 k w ( i , j ) ∣ ∣ x ( i ) − μ j ∣ ∣ 2 SSE=\sum_{i=1}^{n}\sum_{j=1}^{k}w^{(i,j)}||\bm{x^{(i)}}-\bm{\mu^{j}}||^2 SSE=i=1nj=1kw(i,j)x(i)μj2
x i \bm{x^{i}} xi属于簇 j j j,则 w ( i , j ) = 1 w^{(i,j)}=1 w(i,j)=1,否则为0。

K-means算法存在初始点选择不恰当使最终结果为局部最优解或收敛过慢的问题,可以使用K-means++算法进行改进让初始中心点彼此尽量远离,使用该该方法仅需将KMeansinit参数从random改为k-means++

代码与结果

代码引自《python机器学习》,如下所示:

from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
X, y = make_blobs(n_samples=150, n_features=2, centers=3, cluster_std=0.5, shuffle=True, random_state=0)
plt.scatter(X[:,0], X[:,1], c='b', marker='o', s=50)# 将white改为其他的颜色
plt.grid()
plt.show()

km = KMeans(n_clusters=3, init='random', n_init=10, max_iter=300, tol=1e-04, random_state=0)
'''
设定簇数量为3,设置n_init=10,使程序能基于不同的随机初始中心点独立运行算法10次(跳过局部最优解),从中选择SSE最小的作为最终模型。
max_iter参数指定算法每轮运行的迭代次数。
'''
y_km = km.fit_predict(X)

plt.scatter(X[y_km == 0, 0], X[y_km == 0, 1], s=50, c='lightgreen', marker='s', label='cluster 1')
plt.scatter(X[y_km == 1, 0], X[y_km == 1, 1], s=50, c='orange', marker='o', label='cluster 2')
plt.scatter(X[y_km == 2, 0], X[y_km == 2, 1], s=50, c='lightblue', marker='v', label='cluster 3')
plt.scatter(km.cluster_centers_[:, 0],km.cluster_centers_[:, 1], s=250, c='red', marker='*', label='centroids')
plt.legend()
plt.grid()
plt.show()

初始数据集如下图所示:
在这里插入图片描述
聚类分析结果:
在这里插入图片描述

肘分析方法

通过肘分析方法,我们可以选定合适的K值,可以通过如下代码进行分析:

print('Distortion: %.2f' % km.inertia_)# 在完成KMeans模型的拟合后,簇内误差平方和可以通过inertia属性访问
distortions = []
for i in range(1,11) :
    km = KMeans(n_clusters=i, init='k-means++', n_init=10, max_iter=300, random_state=0)
    km.fit(X)
    distortions.append(km.inertia_)
plt.plot(range(1,11), distortions, marker='o',)
plt.xlabel('Number of clusters')
plt.ylabel('Distortion')
plt.show()

结果如图:
在这里插入图片描述
拐点在 K = 3 K=3 K=3处出现,与我们的初始设置是相符的。

  • 15
    点赞
  • 151
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值