混合高斯模型GMMS的EM算法实现

    最近学习机器学习课程,需要实现混合高斯模型的求解。因为混合高斯模型的优化函数,似然函数难以求导优化,所以使用EM算法优化。下面是EM算法的实现。主要参考Andrew Ng的讲义上的公式实现的。




其中,j代表第几个类别,i代表第几个样本数,phi(j)代表类别出现的概率,即p(z(i)=j; phi)。数据满足的模型是混合高斯模型,即p(x|z)是高斯分布,由多个高斯分布以一定的权重相加。下面用matlab实现混合高斯模型。

%by alex
% 说明只需要改变数据X,迭代次数TIME,要分类的类别M即可
clear all;clc


%数据一
% MU1    = [1 2];
% SIGMA1 = [1 0; 0 0.5];
% MU2    = [-1 -1];
% SIGMA2 = [1 0; 0 1];
% X = [mvnrnd(MU1, SIGMA1, 1000); mvnrnd(MU2, SIGMA2, 1000)];
% 
% scatter(X(:,1), X(:,2), 50, '.')


% 数据二
MU1    = [1 2];
SIGMA1 = [1 1; 1 2];
MU2    = [-1 -1];
SIGMA2 = [1 0; 0 1];
MU3    = [-1 -5];
SIGMA3 = [1 0; 0 3];
X = [mvnrnd(MU1, SIGMA1, 1000); mvnrnd(MU2, SIGMA2, 1000); mvnrnd(MU3, SIGMA3, 1000);];


%数据三
% MU1    = [1 2 3];
% SIGMA1 = [1 0 0; 0 0.5 0;0 0 3];
% MU2    = [-1 -1 -1];
% SIGMA2 = [1 0 0; 0 1 0;0 0 1];
% MU3    = [-1 -5 -2];
% SIGMA3 = [1 0 0; 0 3 0; 0 0 2];
% X = [mvnrnd(MU1, SIGMA1, 1000); mvnrnd(MU2, SIGMA2, 1000); mvnrnd(MU3, SIGMA3, 1000);];


%%%%%%%%%%%%%% EM algorithm %%%%%%%%%%%
%用M改变类别数,TIME改变迭代次数


%求数据的整体协方差Covariance
CC = cov(X);


M = 3;   %M代表类别
TIME = 1000; %迭代次数


[N, D] = size(X);  %N为样本总数,D为特征个数
%初始化迭代参数
phi = rand(1,M);
phi = phi./sum(phi);
U = zeros(1,D);
init_u = zeros(M,D);
for i=1:D
    U(i) = mean(X(:,i));
end
for i=1:M
    init_u(i,:) = U(1:D)+i-1;
end
C = cell(1,M);
for i = 1:M
    C{i} = CC;
end


w = zeros(length(X),M);
mol_w = zeros(length(X),M);
den_w = zeros(length(X),M);
last_time = 0;
for k=1:TIME
    %E_step
    tmp = zeros(N,1);
    for i = 1:M
        tmp = tmp + phi(i)*mvnpdf( X, init_u(i,:), C{i} ); %求权系数的分母
    end
    for j = 1:M
        mol_w(:,j) = phi(j)*mvnpdf(X, init_u(j,:), C{j} );   %求权系数的分子
        den_w(:,j) = tmp;
        w = mol_w./den_w;
    end
    %M_step
    last_time = init_u(1,1);    %用于判断迭代次数
    for j = 1:M
        phi(j) = sum(w(:,j))/length(X);        %更新权重
        init_u(j,:) = w(:,j)'*X/sum(w(:,j));   %更新均值
        X_mean = X - repmat(init_u(j,:), N ,1);
        temp_w = repmat(w(:,j)',D,1);
        tmp_C = temp_w.*X_mean'*X_mean;
        C{j} = tmp_C./sum(w(:,j));              %更新方差
    end
    if(abs(init_u(1,1)-last_time) <= 0.000001)  %迭代次数判断
        iter_time = k
        break;
    end
end


%输出参数
phi
for j=1:M
    fprintf('第 %d 类',j)
    mu = init_u(j,:)
    sigma = C{j}
end


%在数据的特征为2维时,画出三维曲面图
if(D == 2)
    x = sym('x',[1,D]);     %定义符号数组
    pdf_u = sym('pdf_u',[M,D]);  %定义符号矩阵
    pdf_out = sym('pdf_out',[1,M]);
    for j=1:M
        pdf_u(j,:) = x - init_u(j,:);
        pdf_out(j) = phi(j)*exp(-0.5*pdf_u(j,:)*inv(C{j})*pdf_u(j,:)')/(2*pi*sqrt(det(C{j})));
    end
    g = sum(pdf_out);
    
    figure;
    ezsurf(g)
    title('混合高斯概率密度函数')
    
    figure
    ezcontour(g)
    title('混合高斯概率密度函数俯视图')
end


  • 0
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值