Gaussian Mixture Model
Basic Concepts
A Gaussian mixture model (GMM) is a weighted superposition of several Gaussian distributions. Mathematically, the probability density functions of several different Gaussians are combined with weights into a new density function that describes the distribution of the samples at hand:
$$p(x)=\sum_{k=1}^K\alpha_k N(\mu_k,\Sigma_k),\quad \text{where } \sum_{k=1}^K\alpha_k=1$$
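As a concrete illustration of this definition, the mixture density can be evaluated by summing the weighted component densities. The sketch below is illustrative only; the two components, their weights, means, and covariances are made-up example values, not from the original post.

```python
import numpy as np
from scipy.stats import multivariate_normal

# Hypothetical K=2 mixture in 2-D; weights alpha_k must sum to 1
alphas = [0.3, 0.7]
means = [np.array([0.0, 0.0]), np.array([3.0, 3.0])]
covs = [np.eye(2), 2.0 * np.eye(2)]

def gmm_pdf(x, alphas, means, covs):
    """Evaluate p(x) = sum_k alpha_k * N(x | mu_k, Sigma_k)."""
    return sum(a * multivariate_normal.pdf(x, mean=m, cov=c)
               for a, m, c in zip(alphas, means, covs))

print(gmm_pdf(np.array([1.0, 1.0]), alphas, means, covs))
```

Each term weights one component's density by its mixing coefficient, so the result is itself a valid density as long as the weights sum to 1.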
Solving the Gaussian Mixture Model
A Gaussian mixture model contains a latent variable: for each sample, we do not know which Gaussian component generated it. Solving a mixture model therefore requires the EM algorithm. The parameters the EM algorithm must estimate are:
$$\theta=(p_1,p_2,\cdots,p_K,\mu_1,\mu_2,\cdots,\mu_K,\Sigma_1,\Sigma_2,\cdots,\Sigma_K)$$
The EM algorithm produces the following iterative updates for $\theta$ (the weight update is included since the weights $p_k$ are part of $\theta$):
$$
\begin{aligned}
&p_k^{(t+1)}=\frac{1}{N}\sum_{i=1}^N P(z_i=C_k\mid x_i,\theta^{(t)}) \\
&\mu_k^{(t+1)}=\frac{\sum_{i=1}^N P(z_i=C_k\mid x_i,\theta^{(t)})\,x_i}{\sum_{i=1}^N P(z_i=C_k\mid x_i,\theta^{(t)})} \\
&\Sigma_k^{(t+1)}=\frac{\sum_{i=1}^N P(z_i=C_k\mid x_i,\theta^{(t)})(x_i-\mu_k^{(t)})(x_i-\mu_k^{(t)})^T}{\sum_{i=1}^N P(z_i=C_k\mid x_i,\theta^{(t)})} \\
&P(z_i\mid x_i,\theta^{(t)})=\frac{p_{z_i}^{(t)}N(x_i\mid\mu_{z_i}^{(t)},\Sigma_{z_i}^{(t)})}{\sum_{k=1}^{K}p_k^{(t)}N(x_i\mid\mu_k^{(t)},\Sigma_k^{(t)})}
\end{aligned}
$$
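One iteration of these updates can be sketched in NumPy as below. This is a minimal illustrative implementation, not the original post's code: `em_step` is a hypothetical helper name, and the covariance update follows the formula above in using the old mean $\mu_k^{(t)}$.

```python
import numpy as np
from scipy.stats import multivariate_normal

def em_step(X, p, mu, Sigma):
    """One EM iteration for a GMM.
    X: (N, d) samples; p: (K,) weights; mu: (K, d) means; Sigma: (K, d, d) covariances."""
    N, _ = X.shape
    K = len(p)
    # E-step: responsibilities gamma[i, k] = P(z_i = C_k | x_i, theta^(t))
    gamma = np.zeros((N, K))
    for k in range(K):
        gamma[:, k] = p[k] * multivariate_normal.pdf(X, mean=mu[k], cov=Sigma[k])
    gamma /= gamma.sum(axis=1, keepdims=True)
    # M-step
    Nk = gamma.sum(axis=0)                # effective number of samples per component
    p_new = Nk / N                        # weight update p_k^(t+1)
    mu_new = (gamma.T @ X) / Nk[:, None]  # mean update mu_k^(t+1)
    Sigma_new = np.zeros_like(Sigma)
    for k in range(K):
        diff = X - mu[k]                  # uses mu_k^(t), as in the formula above
        Sigma_new[k] = (gamma[:, k, None] * diff).T @ diff / Nk[k]
    return p_new, mu_new, Sigma_new
```

In practice this step is repeated until the log-likelihood stops improving; scikit-learn's `GaussianMixture` (used below) wraps exactly this loop.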
Code Implementation: Classification with a Gaussian Mixture Model
# -*- coding: utf-8 -*-
# @Use :
# @Time : 2022/8/27 21:30
# @FileName: MixGaussian.py
# @Software: PyCharm
import matplotlib.pyplot as plt
from sklearn.mixture import GaussianMixture
from sklearn.datasets import make_blobs  # public API; avoid the private _samples_generator module

# Generate 1000 samples from 4 Gaussian blobs
X, y_true = make_blobs(n_samples=1000, centers=4)
fig, ax = plt.subplots(1, 2, sharex='row')
ax[0].scatter(X[:, 0], X[:, 1], s=5, alpha=0.5)

# Fit a 4-component Gaussian mixture via EM
gmm = GaussianMixture(n_components=4)
gmm.fit(X)
print(gmm.weights_)      # mixing weights alpha_k
print(gmm.means_)        # component means mu_k
print(gmm.covariances_)  # component covariances Sigma_k
print(gmm.predict_proba(X)[:10].round(5))  # posteriors P(z_i = C_k | x_i) for the first 10 samples

labels = gmm.predict(X)  # hard assignment to the most probable component
ax[1].scatter(X[:, 0], X[:, 1], s=5, alpha=0.5, c=labels, cmap='viridis')
ax[1].grid(ls='--')
plt.show()