注意:单击此处https://urlify.cn/3auQBf下载完整的示例代码,或通过Binder在浏览器中运行此示例
本示例构建瑞士卷数据集,并在该数据集上运行分层聚类(Hierarchical clustering)。 有关更多信息,请参见分层聚类。 第一步,在不对结构进行连通限制的情况下仅基于距离来进行分层聚类;而在第二步中,聚类仅限于k最近邻图:它是具有结构优先级的分层聚类。 在没有连通性约束的情况下学习到的某些聚类是不遵循瑞士卷结构的,并延伸到集合管(manifolds)的不同折叠处(folds)。相反,当面对相反的连通性约束时,从瑞士卷数据集中可以形成一个很好的聚类分割。 输出:Compute unstructured hierarchical clustering...
Elapsed time: 0.05s
Number of points: 1500
Compute structured hierarchical clustering...
Elapsed time: 0.10s
Number of points: 1500
# 作者 : Vincent Michel, 2010
# Alexandre Gramfort, 2010
# Gael Varoquaux, 2010
# 许可证: BSD 3 clause
print(__doc__)
import time as time
import numpy as np
import matplotlib.pyplot as plt
import mpl_toolkits.mplot3d.axes3d as p3
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import make_swiss_roll
# #############################################################################
# 生成数据(瑞士卷数据集)
n_samples = 1500
noise = 0.05
X, _ = make_swiss_roll(n_samples, noise)
# 使其更薄
X[:, 1] *= .5
# #############################################################################
# 计算聚类
print("Compute unstructured hierarchical clustering...")
st = time.time()
ward = AgglomerativeClustering(n_clusters=6, linkage='ward').fit(X)
elapsed_time = time.time() - st
label = ward.labels_
print("Elapsed time: %.2fs" % elapsed_time)
print("Number of points: %i" % label.size)
# #############################################################################
# 绘制结果
fig = plt.figure()
ax = p3.Axes3D(fig)
ax.view_init(7, -80)
for l in np.unique(label):
ax.scatter(X[label == l, 0], X[label == l, 1], X[label == l, 2],
color=plt.cm.jet(np.float(l) / np.max(label + 1)),
s=20, edgecolor='k')
plt.title('Without connectivity constraints (time %.2fs)' % elapsed_time)
# #############################################################################
# 定义数据的结构A。10个最近邻
from sklearn.neighbors import kneighbors_graph
connectivity = kneighbors_graph(X, n_neighbors=10, include_self=False)
# #############################################################################
# 计算聚类
print("Compute structured hierarchical clustering...")
st = time.time()
ward = AgglomerativeClustering(n_clusters=6, connectivity=connectivity,
linkage='ward').fit(X)
elapsed_time = time.time() - st
label = ward.labels_
print("Elapsed time: %.2fs" % elapsed_time)
print("Number of points: %i" % label.size)
# #############################################################################
# 绘制结果
fig = plt.figure()
ax = p3.Axes3D(fig)
ax.view_init(7, -80)
for l in np.unique(label):
ax.scatter(X[label == l, 0], X[label == l, 1], X[label == l, 2],
color=plt.cm.jet(float(l) / np.max(label + 1)),
s=20, edgecolor='k')
plt.title('With connectivity constraints (time %.2fs)' % elapsed_time)
plt.show()
脚本的总运行时间:(0分钟0.646秒)
估计的内存使用量:25 MB
下载Python源代码: plot_ward_structured_vs_unstructured.py
下载Jupyter notebook源代码: plot_ward_structured_vs_unstructured.ipynb
由Sphinx-Gallery生成的画廊
文壹由“伴编辑器”提供技术支持
☆☆☆为方便大家查阅,小编已将scikit-learn学习路线专栏 文章统一整理到公众号底部菜单栏,同步更新中,关注公众号,点击左下方“系列文章”,如图:欢迎大家和我一起沿着scikit-learn文档这条路线,一起巩固机器学习算法基础。(添加微信:mthler,备注:sklearn学习,一起进【sklearn机器学习进步群】开启打怪升级的学习之旅。)