sklearn.mixture.GaussianMixture
(高斯混合模型,GMM)
GaussianMixture
(高斯混合模型,GMM)是 sklearn.mixture
提供的 基于概率密度分布的聚类方法,适用于 数据服从多个高斯分布的情况,可以用于 数据建模、异常检测、密度估计 等任务。
1. GaussianMixture
作用
- 适用于数据可能属于多个高斯分布的情况(如 金融数据、天气数据)。
- 可进行软聚类(Soft Clustering),不像
KMeans
只能进行硬聚类。 - 适用于异常检测、密度估计、数据生成(如 GMM 生成数据点)。
2. GaussianMixture
vs. KMeans
方法 | 适用情况 | 主要区别 |
---|---|---|
GaussianMixture | 数据服从多个高斯分布,软聚类 | 概率密度建模,可计算样本属于某个簇的概率 |
KMeans | 数据分布均匀,硬聚类 | 所有样本必须属于一个簇 |
示例
- KMeans:每个样本只属于一个簇。
- GMM:每个样本属于多个簇的概率不同(软聚类)。
3. GaussianMixture
代码示例
(1) 训练 GaussianMixture
聚类模型
from sklearn.mixture import GaussianMixture
import numpy as np
# 生成数据
X = np.random.rand(100, 2) # 100 个二维点
# 训练 GMM 聚类模型
model = GaussianMixture(n_components=3, random_state=42)
model.fit(X)
# 预测簇标签
labels = model.predict(X)
print("GMM 簇标签:", labels[:10])
解释
n_components=3
:设定高斯分布的数量(类似KMeans
的n_clusters
)。fit(X)
:训练GMM
,估计高斯分布的参数。predict(X)
:预测数据的簇标签。
4. GaussianMixture
主要参数
GaussianMixture(n_components=3, covariance_type="full", max_iter=100, random_state=42)
参数 | 说明 |
---|---|
n_components | 高斯分布的数量(类似 KMeans 的 n_clusters ) |
covariance_type | 协方差类型("full" 、"tied" 、"diag" 、"spherical" ) |
max_iter | 最大迭代次数(默认 100 ,用于 EM 算法收敛) |
random_state | 随机种子(确保结果可复现) |
5. covariance_type
(协方差类型)
covariance_type | 说明 | 适用情况 |
---|---|---|
"full" | 每个高斯分布都有完整的协方差矩阵 | 适用于任意形状的簇(默认) |
"tied" | 所有高斯分布共享同一个协方差矩阵 | 适用于簇大小相近的情况 |
"diag" | 协方差矩阵为对角矩阵 | 适用于簇形状类似的情况 |
"spherical" | 各个簇的协方差为标量 | 适用于球形簇 |
示例
for cov_type in ["full", "tied", "diag", "spherical"]:
model = GaussianMixture(n_components=3, covariance_type=cov_type, random_state=42)
model.fit(X)
print(f"covariance_type={cov_type}, 预测簇:\n", model.predict(X)[:10])
6. 计算样本属于某个簇的概率
GaussianMixture
可以计算 软聚类(Soft Clustering),即样本属于某个簇的概率:
probs = model.predict_proba(X)
print("前 5 个样本属于各个簇的概率:\n", probs[:5])
解释
predict_proba(X)
返回 样本属于每个簇的概率。
7. GaussianMixture
可视化
import matplotlib.pyplot as plt
import seaborn as sns
sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=labels, palette="coolwarm")
plt.scatter(model.means_[:, 0], model.means_[:, 1], s=200, c="black", marker="X", label="Centroids")
plt.legend()
plt.title("GMM 聚类结果")
plt.show()
解释
- 黑色
X
标记高斯分布的均值(中心点)。
8. 计算聚类性能
from sklearn.metrics import silhouette_score
score = silhouette_score(X, labels)
print("轮廓系数:", score)
解释
silhouette_score(X, labels)
评估聚类效果(值越大越好)。
9. GaussianMixture
vs. KMeans
(环形数据)
from sklearn.datasets import make_moons
# 生成环形数据
X, _ = make_moons(n_samples=200, noise=0.05, random_state=42)
# 训练 KMeans
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=2, random_state=42)
labels_kmeans = kmeans.fit_predict(X)
# 训练 GMM
gmm = GaussianMixture(n_components=2, random_state=42)
labels_gmm = gmm.fit_predict(X)
# 可视化结果
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].scatter(X[:, 0], X[:, 1], c=labels_kmeans, cmap="coolwarm")
axes[0].set_title("KMeans 聚类")
axes[1].scatter(X[:, 0], X[:, 1], c=labels_gmm, cmap="coolwarm")
axes[1].set_title("GaussianMixture 聚类")
plt.show()
解释
KMeans
无法正确分割环形数据(因为它基于欧几里得距离)。GMM
适用于概率密度建模,可以更好地捕捉复杂数据分布。
10. 适用场景
- 密度估计(如异常检测、生成数据)。
- 金融建模(如股票市场分析)。
- 生物信息学(如基因表达分析)。
11. 结论
GaussianMixture
适用于密度估计、概率建模、软聚类任务,相比KMeans
更灵活,但计算复杂度更高。适用于 金融建模、异常检测、生物信息学。