本示例演示高斯混合模型的几种协方差类型。 有关估计器的更多信息,请参见高斯混合模型。 尽管GMM通常用于聚类,但我们可以将获得的聚类与数据集中的实际类别进行比较。我们使用训练集中类别的均值来初始化高斯函数的均值,以便使其比较更加有效。 我们在鸢尾花植物数据集上使用各种GMM协方差类型,来在训练和测试数据上绘制预测标签。我们按性能递增的顺序将GMM的球面(spherical),对角线(diagonal),完全(full)和束缚(tied)协方差矩阵进行比较。完全(full)协方差在总体上表现最好,但是它在小型数据集上有过拟合的倾向,并且不能很好地泛化到测试集的数据上。 在图中,训练数据的显示是点,而测试数据的显示是一个‘十’。鸢尾花植物数据集是四维的,本文仅显示前两个维度,因此某些点在其他维度可能是分开的。 sphx_glr_plot_gmm_sin_001 import itertoolsimport numpy as npfrom scipy import linalgimport matplotlib.pyplot as pltimport matplotlib as mplfrom sklearn import mixtureprint(__doc__)color_iter = itertools.cycle(['navy', 'c', 'cornflowerblue', 'gold', 'darkorange'])def plot_results(X, Y, means, covariances, index, title): splot = plt.subplot(5, 1, 1 + index) for i, (mean, covar, color) in enumerate(zip( means, covariances, color_iter)): v, w = linalg.eigh(covar) v = 2. * np.sqrt(2.) * np.sqrt(v) u = w[0] / linalg.norm(w[0]) # DP不会使用每一个它可以访问的分量,除非它需要使用。 # 我们不应该绘制冗余的分量。 if not np.any(Y == i): continue plt.scatter(X[Y == i, 0], X[Y == i, 1], .8, color=color) # 绘制椭圆以显示高斯分量 angle = np.arctan(u[1] / u[0]) angle = 180. * angle / np.pi # 转换为度数 ell = mpl.patches.Ellipse(mean, v[0], v[1], 180. + angle, color=color) ell.set_clip_box(splot.bbox) ell.set_alpha(0.5) splot.add_artist(ell) plt.xlim(-6., 4. * np.pi - 6.) plt.ylim(-5., 5.) plt.title(title) plt.xticks(()) plt.yticks(())def plot_samples(X, Y, n_components, index, title): plt.subplot(5, 1, 4 + index) for i, color in zip(range(n_components), color_iter): # DP不会使用每一个它可以访问的分量,除非它需要使用。 # 我们不应该绘制冗余的分量。 if not np.any(Y == i): continue plt.scatter(X[Y == i, 0], X[Y == i, 1], .8, color=color) plt.xlim(-6., 4. * np.pi - 6.) plt.ylim(-5., 5.) plt.title(title) plt.xticks(()) plt.yticks(())# 参数n_samples = 100# 根据正弦曲线生成随机样本np.random.seed(0)X = np.zeros((n_samples, 2))step = 4. * np.pi / n_samplesfor i in range(X.shape[0]): x = i * step - 6. X[i, 0] = x + np.random.normal(0, 0.1) X[i, 1] = 3. * (np.sin(x) + np.random.normal(0, .2))plt.figure(figsize=(10, 10))plt.subplots_adjust(bottom=.04, top=0.95, hspace=.2, wspace=.05, left=.03, right=.97)# 使用十个分量拟合用期望最大化(EM)算法的高斯混合模型gmm = mixture.GaussianMixture(n_components=10, covariance_type='full', max_iter=100).fit(X)plot_results(X, gmm.predict(X), gmm.means_, gmm.covariances_, 0, 'Expectation-maximization')dpgmm = mixture.BayesianGaussianMixture( n_components=10, covariance_type='full', weight_concentration_prior=1e-2, weight_concentration_prior_type='dirichlet_process', mean_precision_prior=1e-2, covariance_prior=1e0 * np.eye(2), init_params="random", max_iter=100, random_state=2).fit(X)plot_results(X, dpgmm.predict(X), dpgmm.means_, dpgmm.covariances_, 1, "Bayesian Gaussian mixture models with a Dirichlet process prior " r"for $\gamma_0=0.01$.")X_s, y_s = dpgmm.sample(n_samples=2000)plot_samples(X_s, y_s, dpgmm.n_components, 0, "Gaussian mixture with a Dirichlet process prior " r"for $\gamma_0=0.01$ sampled with $2000$ samples.")dpgmm = mixture.BayesianGaussianMixture( n_components=10, covariance_type='full', weight_concentration_prior=1e+2, weight_concentration_prior_type='dirichlet_process', mean_precision_prior=1e-2, covariance_prior=1e0 * np.eye(2), init_params="kmeans", max_iter=100, random_state=2).fit(X)plot_results(X, dpgmm.predict(X), dpgmm.means_, dpgmm.covariances_, 2, "Bayesian Gaussian mixture models with a Dirichlet process prior " r"for $\gamma_0=100$")X_s, y_s = dpgmm.sample(n_samples=2000)plot_samples(X_s, y_s, dpgmm.n_components, 1, "Gaussian mixture with a Dirichlet process prior " r"for $\gamma_0=100$ sampled with $2000$ samples.")plt.show() 脚本的总运行时间:(0分钟0.614秒) 估计的内存使用量: 8 MB 下载Python源代码: plot_gmm_sin.py 下载Jupyter notebook源代码: plot_gmm_sin.ipynb 由Sphinx-Gallery生成的画廊 ☆☆☆为方便大家查阅,小编已将scikit-learn学习路线专栏文章统一整理到公众号底部菜单栏,同步更新中,关注公众号,点击左下方“系列文章”,如图: 欢迎大家和我一起沿着scikit-learn文档这条路线,一起巩固机器学习算法基础。(添加微信:mthler,备注:sklearn学习,一起进【sklearn机器学习进步群】开启打怪升级的学习之旅。)