EM算法推断混合高斯模型参数

写在前面

今天写了三段代码:
(1)getGMMdata 随机生成混合高斯模型数据
(2)EMGMM 解混合高斯模型参数
(3)测试脚本test.m

(可以修改迭代次数、收敛精度、分布参数观察算法的效果。)

随机生成混合高斯模型的数据

% data是将所有数据混合在一起,data_class是一个cell,将不同高斯分布下的数据混合存放。
% 例如 data=getGMMdata([180,160],[5,3],[0.4,0.6],1000);
%
% 公众号【数学建模公会】,HCLO4原创,20190821.

function [data,data_class]=getGMMdata(mu,sigma,ratio,sampleSize)

numClass=length(mu);
sampleSize_class=zeros(1,numClass);
sampleSize_class(1:numClass-1)=round(sampleSize*ratio(1:numClass-1));
sampleSize_class(end)=sampleSize-sum(sampleSize_class);

data=zeros(sampleSize,1);
data_class=cell(1,numClass);

p=1;
for i=1:numClass
    tempdata=normrnd(mu(i),sigma(i),sampleSize_class(i),1);
    data(p:p+sampleSize_class(i)-1)=tempdata;
    data_class{i}=tempdata;
    p=p+sampleSize_class(i);
end

EM算法估计混合高斯分布的参数

% 输入
% GMMdata:混合高斯分布数据
% Epsilon:迭代精度
% varargin{1}: Maximal iteration steps. (默认100)
%
% 输出
%:mu,均值,sigma,标准差,ratio, 混合比例

% 【数学建模公会】,HCLO4原创,20190821.

function [mu,sigma,ratio]=EMGMM(GMMdata,Epsilon,varargin)

if nargin==2
    maxIter=100;
else
    maxIter=varargin{1};
end

初始化, E步骤

numClass=2;
% mu=randn(1,numClass);
mu=normrnd(mean(GMMdata),std(GMMdata),1,numClass);
varval=var(GMMdata)*ones(1,numClass);
ratio=[0.3,0.7];
sampleSize=length(GMMdata);
postp=zeros(sampleSize,numClass);

for step=1:maxIter

    for i=1:sampleSize

        postp(i,:)=exp(-1*(GMMdata(i)-mu).^2./(2*varval))./sqrt(2*pi*varval);
    end

    sumPostp=sum(postp,2);
    for i=1:sampleSize
        postp(i,:) = postp(i,:)/sumPostp(i);
    end

EM算法,M步骤

preMu=mu;
    preSigma=varval;
    preRatio=ratio;
    for k=1:numClass

        P1=sum(postp(:,k));
        P2=postp(:,k)'* GMMdata;        
        mu(k)=P2/P1;
        ratio(k)=P1/sampleSize;  
        P3=postp(:,k)'*(GMMdata-mu(k)).^2;      
        varval(k) = P3/P1;

    end

% 迭代终止条件:Epsilon

    if sum(abs(mu - preMu)) < Epsilon && ...
        sum(abs(varval - preSigma)) < Epsilon && ...
        sum(abs(ratio - preRatio)) < Epsilon 

        sigma=sqrt(varval);

        break

    end

    sigma=sqrt(varval);

end

测试脚本

mu=[180,160];
sigma=[5,3];
ratio=[0.4,0.6];

sampleSize=10000;
[GMMdata,data_class]=getGMMdata(mu,sigma,ratio,sampleSize);

画出随机产生的数据的实际分布

figure
hold on
for i=1:length(data_class)
    [y,x]=ksdensity(data_class{i});
    plot(x,y,'LineWidth',3)
end

xlabel('Real distribution')


Epsilon=10^-4;
maxIter=100;
[mu,sigma,ratio]=EMGMM(GMMdata,Epsilon,maxIter);

画出推断得到的分布

for i=1:length(mu)
    tempdata=normrnd(mu(i),sigma(i),1,10000);
    [y,x]=ksdensity(tempdata);
    plot(x,y,'--','LineWidth',3)
end

saveas(gca,)

在这里插入图片描述

获取源代码 或 想了解更多内容,欢迎关注公众号:数学建模公会
在这里插入图片描述

  • 0
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值