写在前面
今天写了三段代码:
(1)getGMMdata 随机生成混合高斯模型数据
(2)EMGMM 解混合高斯模型参数
(3)测试脚本test.m
(可以修改迭代次数、收敛精度、分布参数观察算法的效果。)
随机生成混合高斯模型的数据
% data是将所有数据混合在一起,data_class是一个cell,将不同高斯分布下的数据混合存放。
% 例如 data=getGMMdata([180,160],[5,3],[0.4,0.6],1000);
%
% 公众号【数学建模公会】,HCLO4原创,20190821.
function [data,data_class]=getGMMdata(mu,sigma,ratio,sampleSize)
numClass=length(mu);
sampleSize_class=zeros(1,numClass);
sampleSize_class(1:numClass-1)=round(sampleSize*ratio(1:numClass-1));
sampleSize_class(end)=sampleSize-sum(sampleSize_class);
data=zeros(sampleSize,1);
data_class=cell(1,numClass);
p=1;
for i=1:numClass
tempdata=normrnd(mu(i),sigma(i),sampleSize_class(i),1);
data(p:p+sampleSize_class(i)-1)=tempdata;
data_class{i}=tempdata;
p=p+sampleSize_class(i);
end
EM算法估计混合高斯分布的参数
% 输入
% GMMdata:混合高斯分布数据
% Epsilon:迭代精度
% varargin{1}: Maximal iteration steps. (默认100)
%
% 输出
%:mu,均值,sigma,标准差,ratio, 混合比例
% 【数学建模公会】,HCLO4原创,20190821.
function [mu,sigma,ratio]=EMGMM(GMMdata,Epsilon,varargin)
if nargin==2
maxIter=100;
else
maxIter=varargin{1};
end
初始化, E步骤
numClass=2;
% mu=randn(1,numClass);
mu=normrnd(mean(GMMdata),std(GMMdata),1,numClass);
varval=var(GMMdata)*ones(1,numClass);
ratio=[0.3,0.7];
sampleSize=length(GMMdata);
postp=zeros(sampleSize,numClass);
for step=1:maxIter
for i=1:sampleSize
postp(i,:)=exp(-1*(GMMdata(i)-mu).^2./(2*varval))./sqrt(2*pi*varval);
end
sumPostp=sum(postp,2);
for i=1:sampleSize
postp(i,:) = postp(i,:)/sumPostp(i);
end
EM算法,M步骤
preMu=mu;
preSigma=varval;
preRatio=ratio;
for k=1:numClass
P1=sum(postp(:,k));
P2=postp(:,k)'* GMMdata;
mu(k)=P2/P1;
ratio(k)=P1/sampleSize;
P3=postp(:,k)'*(GMMdata-mu(k)).^2;
varval(k) = P3/P1;
end
% 迭代终止条件:Epsilon
if sum(abs(mu - preMu)) < Epsilon && ...
sum(abs(varval - preSigma)) < Epsilon && ...
sum(abs(ratio - preRatio)) < Epsilon
sigma=sqrt(varval);
break
end
sigma=sqrt(varval);
end
测试脚本
mu=[180,160];
sigma=[5,3];
ratio=[0.4,0.6];
sampleSize=10000;
[GMMdata,data_class]=getGMMdata(mu,sigma,ratio,sampleSize);
画出随机产生的数据的实际分布
figure
hold on
for i=1:length(data_class)
[y,x]=ksdensity(data_class{i});
plot(x,y,'LineWidth',3)
end
xlabel('Real distribution')
Epsilon=10^-4;
maxIter=100;
[mu,sigma,ratio]=EMGMM(GMMdata,Epsilon,maxIter);
画出推断得到的分布
for i=1:length(mu)
tempdata=normrnd(mu(i),sigma(i),1,10000);
[y,x]=ksdensity(tempdata);
plot(x,y,'--','LineWidth',3)
end
saveas(gca,)
获取源代码 或 想了解更多内容,欢迎关注公众号:数学建模公会