import numpy as np
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import dendrogram
from sklearn.cluster import AgglomerativeClustering
from sklearn.datasets import load_iris
# **kwargs 是一个 Python 中的特殊语法,
# 用于在函数参数中接收不定数量的关键字参数。
# 并将其作为一个列表类型传递给函数。
def plot_dendrogram(model, **kwargs):
# 创建每个节点下样本的计数
counts = np.zeros(model.children_.shape[0])
n_samples = len(model.labels_)
for i, merge in enumerate(model.children_):
current_count = 0
for child_idx in merge:
if child_idx < n_samples:
current_count += 1 # 叶节点
else:
current_count += counts[child_idx - n_samples]
counts[i] = current_count
# 创建联接矩阵,如何理解这个矩阵的每个元素,见后边的心得
linkage_matrix = np.column_stack(
[model.children_, model.distances_, counts]
).astype(float)
# 绘制对应的树状图
dendrogram(linkage_matrix, **kwargs)
# 还是以鸢尾花数据为例
iris = load_iris()
# 提取鸢尾花数据的特征
X = iris.data
# 凝聚层次聚类
# 设置distance_threshold=0确保计算完整的树
# n_clusters:要形成的聚类数量,设置为 None 计算到最后合并成一个类。
model = AgglomerativeClustering(distance_threshold=0, n_clusters=None)
# fit进行数据拟合后,可以通过 labels_ 属性获取每个样本的类别标签。
model = model.fit(X)
plt.title("Hierarchical Clustering Dendrogram")
# 绘制树状图的前三个层级
plot_dendrogram(model, truncate_mode="level", p=3)
plt.xlabel("Number of points in node (or index of point if no parenthesis)")
plt.show()
心得:
1.
Q:AgglomerativeClustering(distance_threshold=0, n_clusters=None)中n_clusters设置为none,表示所有数据划分成一个类,那还有什么聚类的必要了吗?
A:将 n_clusters 设置为 None 主要用于对数据的层次结构进行分析,而不是真正意义上的聚类,有时被称为层次聚类的树状图。
2.
Q:distance_threshold=0 和 n_clusters=None 两者有何区别?看似一样
A:当 distance_threshold 设为非 0 值时,聚类算法将根据 distance_threshold 的值来分割数据,并形成多个聚类。而当将 distance_threshold 设为 0 时,类似于将 n_clusters 设为 None,聚类算法一直合并类别,直到所有的样本都被合并成一个大类。前者的优先级更高。
3.AgglomerativeClustering 与树结构(也称为树状图)密切相关。层次聚类算法的最终结果是一棵树状图,其中树的每个节点表示一个聚类,叶节点表示单个样本。
4.AgglomerativeClustering 有一个重要的属性 children_是一个数组,用于存储层次聚类过程中每个样本或聚类节点的合并信息。其返回值是一个二维数组,形状为 (n_samples-1, 2)。数组的每一行表示一次合并操作,在合并的过程中,最相似的两个样本或聚类会成为一个新的节点。最终,children_ 数组中的最后一行将保存整个层次聚类的结果。
5.要最终画出树形图,关键是要得到连接矩阵。
6.如何理解这个矩阵?
矩阵形如:
[
[cluster_index_1, cluster_index_2, distance, num_members],
[cluster_index_3, cluster_index_4, distance, num_members],
...
]
其中,cluster_index_1 和 cluster_index_2 表示被合并的两个聚类的索引,distance 表示这两个聚类之间的距离或相似度,num_members 表示新合并的聚类中包含的样本数量。