此示例的目标是直观地显示度量(metrics)的行为,而不是找到好的手写数字的聚类,这就是为什么该示例适用于二维嵌入的原因。
该示例向我们展示的是层次聚类的行为-“富人越来富(rich getting richer)”,这往往会导致大小不均匀的聚类。对于平均链接策略,此行为是明显会发生的,它以几个单聚类结束,而在单链接的情况下,我们得到一个中央聚类,所有其他聚类均从中央聚类边缘周围的噪声点中得出。 sphx_glr_plot_digits_linkage_002 sphx_glr_plot_digits_linkage_003 输出:Computing embedding
Done.
ward : 0.31s
average : 0.30s
complete : 0.29s
single : 0.21s
# 作者: Gael Varoquaux# 许可证: BSD 3 clause (C) INRIA 2014
print(__doc__)from time import timeimport numpy as npfrom scipy import ndimagefrom matplotlib import pyplot as pltfrom sklearn import manifold, datasets
X, y = datasets.load_digits(return_X_y=True)
n_samples, n_features = X.shape
np.random.seed(0)def nudge_images(X, y):# 拥有更大的数据集可以更清楚地显示这个方法的行为,# 但我们仅将数据集的大小乘以2,# 因为层次聚类方法的计算成本是# 是n_samples的超线性(super-linear)
shift = lambda x: ndimage.shift(x.reshape((8, 8)),.3 * np.random.normal(size=2),
mode='constant',
).ravel()
X = np.concatenate([X, np.apply_along_axis(shift, 1, X)])
Y = np.concatenate([y, y], axis=0)return X, Y
X, y = nudge_images(X, y)#----------------------------------------------------------------------# 可视化聚类def plot_clustering(X_red, labels, title=None):
x_min, x_max = np.min(X_red, axis=0), np.max(X_red, axis=0)
X_red = (X_red - x_min) / (x_max - x_min)
plt.figure(figsize=(6, 4))for i in range(X_red.shape[0]):
plt.text(X_red[i, 0], X_red[i, 1], str(y[i]),
color=plt.cm.nipy_spectral(labels[i] / 10.),
fontdict={'weight': 'bold', 'size': 9})
plt.xticks([])
plt.yticks([])if title is not None:
plt.title(title, size=17)
plt.axis('off')
plt.tight_layout(rect=[0, 0.03, 1, 0.95])#----------------------------------------------------------------------# 手写数字识别数据集的二维嵌入
print("Computing embedding")
X_red = manifold.SpectralEmbedding(n_components=2).fit_transform(X)
print("Done.")from sklearn.cluster import AgglomerativeClusteringfor linkage in ('ward', 'average', 'complete', 'single'):
clustering = AgglomerativeClustering(linkage=linkage, n_clusters=10)
t0 = time()
clustering.fit(X_red)
print("%s :\t%.2fs" % (linkage, time() - t0))
plot_clustering(X_red, clustering.labels_, "%s linkage" % linkage)
plt.show()
脚本的总运行时间:(0分钟25.325秒)
估计的内存使用量: 152 MB
下载Python源代码: plot_segmentation_toy.py
下载Jupyter notebook源代码: plot_segmentation_toy.ipynb
本文由“壹伴编辑器”提供技术支持
☆☆☆为方便大家查阅,小编已将scikit-learn学习路线专栏文章统一整理到公众号底部菜单栏,同步更新中,关注公众号,点击左下方“系列文章”,如图:欢迎大家和我一起沿着scikit-learn文档这条路线,一起巩固机器学习算法基础。(添加微信:mthler,备注:sklearn学习,一起进【sklearn机器学习进步群】开启打怪升级的学习之旅。)