【scikit-learn】sklearn.mixture.GaussianMixture 类:高斯混合模型(GMM)

sklearn.mixture.GaussianMixture(高斯混合模型,GMM)

GaussianMixture(高斯混合模型,GMM)是 sklearn.mixture 提供的 基于概率密度分布的聚类方法,适用于 数据服从多个高斯分布的情况,可以用于 数据建模、异常检测、密度估计 等任务。


1. GaussianMixture 作用

  • 适用于数据可能属于多个高斯分布的情况(如 金融数据、天气数据)。
  • 可进行软聚类(Soft Clustering),不像 KMeans 只能进行硬聚类。
  • 适用于异常检测、密度估计、数据生成(如 GMM 生成数据点)。

2. GaussianMixture vs. KMeans

方法适用情况主要区别
GaussianMixture数据服从多个高斯分布,软聚类概率密度建模,可计算样本属于某个簇的概率
KMeans数据分布均匀,硬聚类所有样本必须属于一个簇

示例

  • KMeans:每个样本只属于一个簇。
  • GMM:每个样本属于多个簇的概率不同(软聚类)。

3. GaussianMixture 代码示例

(1) 训练 GaussianMixture 聚类模型

from sklearn.mixture import GaussianMixture
import numpy as np

# 生成数据
X = np.random.rand(100, 2)  # 100 个二维点

# 训练 GMM 聚类模型
model = GaussianMixture(n_components=3, random_state=42)
model.fit(X)

# 预测簇标签
labels = model.predict(X)
print("GMM 簇标签:", labels[:10])

解释

  • n_components=3:设定高斯分布的数量(类似 KMeansn_clusters)。
  • fit(X):训练 GMM,估计高斯分布的参数。
  • predict(X):预测数据的簇标签。

4. GaussianMixture 主要参数

GaussianMixture(n_components=3, covariance_type="full", max_iter=100, random_state=42)
参数说明
n_components高斯分布的数量(类似 KMeansn_clusters
covariance_type协方差类型"full""tied""diag""spherical"
max_iter最大迭代次数(默认 100,用于 EM 算法收敛)
random_state随机种子(确保结果可复现)

5. covariance_type(协方差类型)

covariance_type说明适用情况
"full"每个高斯分布都有完整的协方差矩阵适用于任意形状的簇(默认)
"tied"所有高斯分布共享同一个协方差矩阵适用于簇大小相近的情况
"diag"协方差矩阵为对角矩阵适用于簇形状类似的情况
"spherical"各个簇的协方差为标量适用于球形簇

示例

for cov_type in ["full", "tied", "diag", "spherical"]:
    model = GaussianMixture(n_components=3, covariance_type=cov_type, random_state=42)
    model.fit(X)
    print(f"covariance_type={cov_type}, 预测簇:\n", model.predict(X)[:10])

6. 计算样本属于某个簇的概率

GaussianMixture 可以计算 软聚类(Soft Clustering),即样本属于某个簇的概率:

probs = model.predict_proba(X)
print("前 5 个样本属于各个簇的概率:\n", probs[:5])

解释

  • predict_proba(X) 返回 样本属于每个簇的概率

7. GaussianMixture 可视化

import matplotlib.pyplot as plt
import seaborn as sns

sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=labels, palette="coolwarm")
plt.scatter(model.means_[:, 0], model.means_[:, 1], s=200, c="black", marker="X", label="Centroids")
plt.legend()
plt.title("GMM 聚类结果")
plt.show()

解释

  • 黑色 X 标记高斯分布的均值(中心点)

8. 计算聚类性能

from sklearn.metrics import silhouette_score

score = silhouette_score(X, labels)
print("轮廓系数:", score)

解释

  • silhouette_score(X, labels) 评估聚类效果(值越大越好)。

9. GaussianMixture vs. KMeans(环形数据)

from sklearn.datasets import make_moons

# 生成环形数据
X, _ = make_moons(n_samples=200, noise=0.05, random_state=42)

# 训练 KMeans
from sklearn.cluster import KMeans
kmeans = KMeans(n_clusters=2, random_state=42)
labels_kmeans = kmeans.fit_predict(X)

# 训练 GMM
gmm = GaussianMixture(n_components=2, random_state=42)
labels_gmm = gmm.fit_predict(X)

# 可视化结果
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
axes[0].scatter(X[:, 0], X[:, 1], c=labels_kmeans, cmap="coolwarm")
axes[0].set_title("KMeans 聚类")
axes[1].scatter(X[:, 0], X[:, 1], c=labels_gmm, cmap="coolwarm")
axes[1].set_title("GaussianMixture 聚类")
plt.show()

解释

  • KMeans 无法正确分割环形数据(因为它基于欧几里得距离)。
  • GMM 适用于概率密度建模,可以更好地捕捉复杂数据分布。

10. 适用场景

  • 密度估计(如异常检测、生成数据)
  • 金融建模(如股票市场分析)
  • 生物信息学(如基因表达分析)

11. 结论

  • GaussianMixture 适用于密度估计、概率建模、软聚类任务,相比 KMeans 更灵活,但计算复杂度更高。适用于 金融建模、异常检测、生物信息学
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

彬彬侠

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值