Note: Click here (https://urlify.cn/BVbiUv) to download the full example code, or run this example in your browser via Binder.
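In this example we compare the various initialization strategies for K-means in terms of runtime and quality of the results.
Because the ground truth is known here, we also apply several cluster quality metrics to judge how well the cluster labels fit the true class labels (ground truth).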
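To make the two built-in strategies more concrete, here is a minimal NumPy sketch of what `init='random'` and `init='k-means++'` roughly do. The helper names `random_seeds` and `kmeans_plusplus_seeds` are hypothetical, and the k-means++ version is simplified: it draws a single candidate per step, whereas scikit-learn's implementation greedily evaluates several candidates per step and keeps the best one. It also assumes the samples are not all identical (otherwise the D² probabilities are undefined).

```python
import numpy as np


def random_seeds(X, n_clusters, seed=0):
    """'random' strategy: pick n_clusters distinct samples uniformly."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(X.shape[0], size=n_clusters, replace=False)
    return X[idx]


def kmeans_plusplus_seeds(X, n_clusters, seed=0):
    """Simplified k-means++ seeding (D^2 sampling), hypothetical helper.

    One candidate per step; scikit-learn tries several and keeps the best.
    """
    rng = np.random.default_rng(seed)
    n_samples = X.shape[0]
    # First center: a sample chosen uniformly at random.
    first = rng.integers(n_samples)
    centers = [X[first]]
    # Squared distance from every sample to its closest chosen center.
    closest_sq = np.sum((X - X[first]) ** 2, axis=1)
    for _ in range(1, n_clusters):
        # Next center is drawn with probability proportional to D(x)^2,
        # so samples far from all existing centers are favoured and
        # already-chosen samples (distance 0) can never be picked again.
        probs = closest_sq / closest_sq.sum()
        idx = rng.choice(n_samples, p=probs)
        centers.append(X[idx])
        closest_sq = np.minimum(closest_sq, np.sum((X - X[idx]) ** 2, axis=1))
    return np.array(centers)


# Toy data: two well-separated pairs of points (made up for illustration).
X = np.array([[0.0, 0.0], [0.1, 0.0], [10.0, 10.0], [10.1, 10.0]])
centers = kmeans_plusplus_seeds(X, n_clusters=2)
print(centers)
```

Either array can be passed to `KMeans(init=..., n_init=1)`, which is also how the PCA-based strategy in the script works: it supplies an explicit array of initial centers instead of a strategy name.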
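Metrics used to evaluate clustering quality (for their definitions and a discussion, see "Clustering performance evaluation" in the scikit-learn user guide):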
| Shorthand | Full name |
|---|---|
| homo | homogeneity score |
| compl | completeness score |
| v-meas | V-measure |
| ARI | adjusted Rand index |
| AMI | adjusted mutual information |
| silhouette | silhouette coefficient |
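A quick check on toy labels can help when reading these columns. The sketch below uses the real `sklearn.metrics` functions that the script also calls, on hypothetical label lists made up for illustration. It shows that all five label-based scores ignore how clusters are numbered, and that splitting one class across two clusters keeps homogeneity at 1 while lowering completeness. The silhouette coefficient is left out because it needs the feature matrix, not just labels.

```python
from sklearn import metrics

# Hypothetical toy labels, made up for illustration.
labels_true = [0, 0, 1, 1, 2, 2]
# Same partition as labels_true, only the cluster IDs are permuted.
labels_permuted = [1, 1, 2, 2, 0, 0]
# Class 0 is split across clusters 0 and 1; every cluster is still pure.
labels_split = [0, 1, 2, 2, 3, 3]


def score_all(labels_true, labels_pred):
    """Compute the label-based metrics listed in the table."""
    return {
        "homo": metrics.homogeneity_score(labels_true, labels_pred),
        "compl": metrics.completeness_score(labels_true, labels_pred),
        "v-meas": metrics.v_measure_score(labels_true, labels_pred),
        "ARI": metrics.adjusted_rand_score(labels_true, labels_pred),
        "AMI": metrics.adjusted_mutual_info_score(labels_true, labels_pred),
    }


permuted = score_all(labels_true, labels_permuted)
split = score_all(labels_true, labels_split)
print("permuted:", {k: round(v, 3) for k, v in permuted.items()})
print("split:   ", {k: round(v, 3) for k, v in split.items()})
```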
Output:
```
n_digits: 10,    n_samples 1797,    n_features 64
__________________________________________________________________________________
init            time    inertia homo    compl   v-meas  ARI     AMI     silhouette
k-means++       0.24s   69510   0.610   0.657   0.633   0.481   0.629   0.129
random          0.24s   69907   0.633   0.674   0.653   0.518   0.649   0.131
PCA-based       0.03s   70768   0.668   0.695   0.681   0.558   0.678   0.142
__________________________________________________________________________________
```
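The `inertia` column is the k-means objective: the sum of squared Euclidean distances from each sample to the center of the cluster it is assigned to. In the output above the lowest inertia (k-means++) does not coincide with the best agreement with the ground truth (PCA-based), so inertia alone is not a measure of label quality. The sketch below recomputes inertia by hand on the same scaled digits data; `random_state=0` is an addition here, used only to make the run repeatable.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits
from sklearn.preprocessing import scale

data = scale(load_digits().data)
km = KMeans(init="k-means++", n_clusters=10, n_init=10,
            random_state=0).fit(data)

# Look up each sample's assigned center, then sum the squared distances.
assigned_centers = km.cluster_centers_[km.labels_]
manual_inertia = np.sum((data - assigned_centers) ** 2)
print(km.inertia_, manual_inertia)
```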
```python
from time import time

import numpy as np
import matplotlib.pyplot as plt

from sklearn import metrics
from sklearn.cluster import KMeans
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.preprocessing import scale

np.random.seed(42)

X_digits, y_digits = load_digits(return_X_y=True)
data = scale(X_digits)

n_samples, n_features = data.shape
n_digits = len(np.unique(y_digits))
labels = y_digits

sample_size = 300

print("n_digits: %d, \t n_samples %d, \t n_features %d"
      % (n_digits, n_samples, n_features))

print(82 * '_')
print('init\t\ttime\tinertia\thomo\tcompl\tv-meas\tARI\tAMI\tsilhouette')


def bench_k_means(estimator, name, data):
    t0 = time()
    estimator.fit(data)
    print('%-9s\t%.2fs\t%i\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f\t%.3f'
          % (name, (time() - t0), estimator.inertia_,
             metrics.homogeneity_score(labels, estimator.labels_),
             metrics.completeness_score(labels, estimator.labels_),
             metrics.v_measure_score(labels, estimator.labels_),
             metrics.adjusted_rand_score(labels, estimator.labels_),
             metrics.adjusted_mutual_info_score(labels, estimator.labels_),
             metrics.silhouette_score(data, estimator.labels_,
                                      metric='euclidean',
                                      sample_size=sample_size)))


bench_k_means(KMeans(init='k-means++', n_clusters=n_digits, n_init=10),
              name="k-means++", data=data)

bench_k_means(KMeans(init='random', n_clusters=n_digits, n_init=10),
              name="random", data=data)

# In this case the seeding of the centers is deterministic,
# so we run the k-means algorithm only once (n_init=1).
pca = PCA(n_components=n_digits).fit(data)
bench_k_means(KMeans(init=pca.components_, n_clusters=n_digits, n_init=1),
              name="PCA-based",
              data=data)
print(82 * '_')

# #############################################################################
# Visualize the results on PCA-reduced data

reduced_data = PCA(n_components=2).fit_transform(data)
kmeans = KMeans(init='k-means++', n_clusters=n_digits, n_init=10)
kmeans.fit(reduced_data)

# Step size of the mesh. Decrease it to increase the quality of the
# VQ (vector quantization).
h = .02     # spacing of points in the mesh [x_min, x_max] x [y_min, y_max]

# Plot the decision boundary. For that, we will assign a color to each region.
x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1
y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))

# Obtain a label for each point in the mesh using the trained model.
Z = kmeans.predict(np.c_[xx.ravel(), yy.ravel()])

# Put the result into a color plot
Z = Z.reshape(xx.shape)
plt.figure(1)
plt.clf()
plt.imshow(Z, interpolation='nearest',
           extent=(xx.min(), xx.max(), yy.min(), yy.max()),
           cmap=plt.cm.Paired,
           aspect='auto', origin='lower')

plt.plot(reduced_data[:, 0], reduced_data[:, 1], 'k.', markersize=2)
# Plot the centroids as a white X
centroids = kmeans.cluster_centers_
plt.scatter(centroids[:, 0], centroids[:, 1],
            marker='x', s=169, linewidths=3,
            color='w', zorder=10)
plt.title('K-means clustering on the digits dataset (PCA-reduced data)\n'
          'Centroids are marked with white cross')
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.xticks(())
plt.yticks(())
plt.show()
```
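`bench_k_means` calls `metrics.silhouette_score` with `sample_size=300`, so the reported silhouette is the mean over a random subsample of 300 points; with no `random_state` given, that subsample is drawn from NumPy's global RNG, which `np.random.seed(42)` fixes. The per-sample coefficient is s = (b - a) / max(a, b), where a is the mean distance to the other members of the sample's own cluster and b is the smallest mean distance to the members of any other cluster. The sketch below recomputes it by hand on made-up two-blob data and compares it with `sklearn.metrics.silhouette_samples`; `silhouette_by_hand` is a hypothetical helper that assumes every cluster has at least two members.

```python
import numpy as np
from sklearn.metrics import pairwise_distances, silhouette_samples


def silhouette_by_hand(X, labels):
    """Per-sample silhouette s = (b - a) / max(a, b), hypothetical helper."""
    D = pairwise_distances(X)  # Euclidean distances by default
    labels = np.asarray(labels)
    s = np.empty(len(X))
    for i in range(len(X)):
        same = labels == labels[i]
        # D[i, i] is 0, so dividing by (size - 1) excludes the sample itself.
        a = D[i, same].sum() / (same.sum() - 1)
        # Closest other cluster, measured by mean distance to its members.
        b = min(D[i, labels == k].mean()
                for k in np.unique(labels) if k != labels[i])
        s[i] = (b - a) / max(a, b)
    return s


# Made-up data: two Gaussian blobs of 20 points each.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0, 1, (20, 2)), rng.normal(5, 1, (20, 2))])
labels = np.repeat([0, 1], 20)
print(np.allclose(silhouette_by_hand(X, labels), silhouette_samples(X, labels)))
```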
Total running time of the script: (0 minutes 1.122 seconds)
Estimated memory usage: 46 MB
Download Python source code: plot_kmeans_digits.py
Download Jupyter notebook: plot_kmeans_digits.ipynb