# EM算法与GMM的训练应用

### EM算法

#### M-step:优化下界，即求取

$$\arg\max \sum_{i=1}^{m}\sum_{z^{(i)}=1}^{k} Q_i(z^{(i)})\, \log\frac{p(x^{(i)}\mid z^{(i)};\mu_{z^{(i)}},\Sigma_{z^{(i)}})\, p(z^{(i)};\phi)}{Q_i(z^{(i)})}$$

### Matlab实现

#### 根据以上推导，可以很容易实现EM算法估计GMM参数。现以1维数据2个高斯混合概率密度估计作为实例，详细代码如下所示。

% fitting_a_gmm.m
% Minimal EM demo: fit a 2-component 1-D Gaussian mixture to synthetic data.
% Hongliang He 2014/03
clear
close all
clc

% --- generate synthetic data: two Gaussians plus a little uniform noise ---
len1 = 1000;
len2 = fix(len1 * 1.5);
data = [normrnd(0, 1, [1 len1])  normrnd(4, 2, [1 len2])] + 0.1*rand([1 len1+len2]);
data_len = length(data);

% --- EM settings ---
ite_cnt = 100000;   % maximum number of iterations
max_err = 1e-5;     % convergence tolerance on parameter change

% initial guesses (soft-assignment EM)
z0 = 0.5;           % prior (mixing) probability of component 0
z1 = 1 - z0;
u  = mean(data);
u0 = 1.2 * u;       % perturb the global mean so the components can separate
u1 = 0.8 * u;
sigma0 = 1;
sigma1 = 1;

iteration = 0;
while iteration < ite_cnt
    % responsibilities (posterior probabilities), one per sample
    w0 = zeros(1, data_len);
    w1 = zeros(1, data_len);

    % E-step: update Q_i (responsibilities) to tighten the lower bound
    for k1 = 1:data_len
        p0 = z0 * gauss(data(k1), u0, sigma0);
        p1 = z1 * gauss(data(k1), u1, sigma1);
        p  = p0 / (p0 + p1);

        if p0 == 0 && p1 == 0
            % Both densities underflowed, so p above is NaN (0/0).
            % Hard-assign the point to the component with the nearer mean
            % (0.5 on an exact tie) instead of leaving NaN behind.
            dist0 = (data(k1) - u0)^2;
            dist1 = (data(k1) - u1)^2;
            if dist0 < dist1
                p = 1;        % closer to component 0
            elseif dist0 > dist1
                p = 0;        % closer to component 1
            else
                p = 0.5;      % equidistant from both means
            end
        end

        w0(k1) = p;           % posterior for component 0
        w1(k1) = 1 - p;
    end

    % remember the previous parameters for the convergence test
    old_u0 = u0;
    old_u1 = u1;
    old_sigma0 = sigma0;
    old_sigma1 = sigma1;

    % M-step: maximize the lower bound (weighted maximum-likelihood updates)
    u0 = sum(w0 .* data) / sum(w0);
    u1 = sum(w1 .* data) / sum(w1);
    sigma0 = sqrt( sum(w0 .* (data - u0).^2) / sum(w0) );
    sigma1 = sqrt( sum(w1 .* (data - u1).^2) / sum(w1) );
    z0 = sum(w0) / data_len;
    z1 = sum(w1) / data_len;

    % progress report every 10 iterations
    if mod(iteration, 10) == 0
        fprintf('%d: u0=%f,d0=%f u1=%f,d1=%f\n', iteration, ...
                u0, sigma0, u1, sigma1);
    end

    d_u0 = abs(u0 - old_u0);
    d_u1 = abs(u1 - old_u1);
    d_sigma0 = abs(sigma0 - old_sigma0);
    d_sigma1 = abs(sigma1 - old_sigma1);

    % stop once every parameter has (nearly) stopped moving
    if d_u0 < max_err && d_u1 < max_err && ...
       d_sigma0 < max_err && d_sigma1 < max_err
        clc
        fprintf('ite = %d, final value is\n', iteration);
        fprintf('u0=%f,d0=%f  u1=%f,d1=%f\n', u0, sigma0, u1, sigma1);
        break;
    end

    iteration = iteration + 1;
end

% --- compare the estimated density against a histogram of the data ---
my_hist(data, 20);
hold on;
mi = min(data);
mx = max(data);
t  = linspace(mi, mx, 100);
y  = z0*gauss(t, u0, sigma0) + z1*gauss(t, u1, sigma1);
plot(t, y, 'r', 'linewidth', 5);

% gauss.m
% Density of the 1-D normal distribution N(u, sigma^2), evaluated at x.
% Hongliang He 2014/03
function y = gauss(x, u, sigma)
    % quadratic exponent kept as a named intermediate for readability
    expo = -0.5*(x-u).^2/sigma.^2;
    y = exp(expo) ./ (sqrt(2*pi)*sigma);
end

% my_hist.m
% Estimate a probability density from samples with a normalized histogram
% of cnt bins and draw it as a bar plot.
% 2013/03
function my_hist(data, cnt)
    dat_len = length(data);
    % require a few samples per bin so the estimate is not pure noise
    if dat_len < cnt*5
        error('There are not enough data!\n')
    end

    mi = min(data);
    ma = max(data);
    if ma <= mi
        error('sorry, there is only one type of data\n')
    end

    % cnt bins need cnt+1 edges; dt is the TRUE bin width (the original
    % divided by (ma-mi)/cnt while the edges were (ma-mi)/(cnt-1) apart,
    % mis-scaling the density by cnt/(cnt-1))
    edges = linspace(mi, ma, cnt + 1);
    dt = edges(2) - edges(1);

    y = zeros(1, cnt);   % preallocate bin counts
    for k1 = 1:cnt-1
        y(k1) = sum( data >= edges(k1) & data < edges(k1+1) );
    end
    % last bin is closed on the right so samples equal to max(data)
    % are counted (the original dropped them)
    y(cnt) = sum( data >= edges(cnt) & data <= edges(cnt+1) );

    % counts -> density: sum(y)*dt == 1
    y = y ./ dat_len / dt;

    % draw bars at bin centers
    bar(edges(1:cnt) + 0.5*dt, y);
end