MATLAB有用的混合高斯EM算法,害怕以后找不到

1

X=   zeros(600,2);

 2

X(1:200,:) = normrnd(0,1,200,2);

 3

X(201:400,:) = normrnd(0,2,200,2);

4

X(401:600,:) = normrnd(0,3,200,2);

5

[W,M,V,L] = EM_GM(X,3,[],[],1,[])

下面是程序源码:

查看源代码

打印帮助

001

function[W,M,V,L] = EM_GM(X,k,ltol,maxiter,pflag,Init)

002

% [W,M,V,L] = EM_GM(X,k,ltol,maxiter,pflag,Init)

003

%

004

% EM algorithm for k multidimensional Gaussian mixture estimation

005

%

006

% Inputs:

007

%   X(n,d) - input data, n=number of observations, d=dimension of variable

008

%   k - maximum number of Gaussian components allowed

009

%   ltol - percentage of the log likelihood difference between 2 iterations ([] for none)

010

%   maxiter - maximum number of iteration allowed ([] for none)

011

%   pflag - 1 for plotting GM for 1D or 2D cases only, 0 otherwise ([] for none)

012

%   Init - structure of initial W, M, V: Init.W, Init.M, Init.V ([] for none)

013

%

014

% Ouputs:

015

%   W(1,k) - estimated weights of GM

016

%   M(d,k) - estimated mean vectors of GM

017

%   V(d,d,k) - estimated covariance matrices of GM

018

%   L - log likelihood of estimates

019

%

020

% Written by

021

%   Patrick P. C. Tsui,

022

%   PAMI research group

023

%   Department of Electrical and Computer Engineering

024

%   University of Waterloo,

025

%   March, 2006

026

%

027

  

028

%%%% Validate inputs %%%%

029

ifnargin <= 1,

030

 disp('EM_GM must have at least 2 inputs: X,k!/n')

031

 return

032

elseifnargin == 2,

033

 ltol = 0.1; maxiter = 1000; pflag = 0; Init = [];

034

 err_X = Verify_X(X);

035

 err_k = Verify_k(k);

036

 iferr_X | err_k,return;end

037

elseifnargin == 3,

038

 maxiter = 1000; pflag = 0; Init = [];

039

 err_X = Verify_X(X);

040

 err_k = Verify_k(k);

041

 [ltol,err_ltol] = Verify_ltol(ltol);

042

 iferr_X | err_k | err_ltol,return;end

043

elseifnargin == 4,

044

 pflag = 0;  Init = [];

045

 err_X = Verify_X(X);

046

 err_k = Verify_k(k);

047

 [ltol,err_ltol] = Verify_ltol(ltol);

048

 [maxiter,err_maxiter] = Verify_maxiter(maxiter);

049

 iferr_X | err_k | err_ltol | err_maxiter,return;end

050

elseifnargin == 5,

051

 Init = [];

052

 err_X = Verify_X(X);

053

 err_k = Verify_k(k);

054

 [ltol,err_ltol] = Verify_ltol(ltol);

055

 [maxiter,err_maxiter] = Verify_maxiter(maxiter);

056

 [pflag,err_pflag] = Verify_pflag(pflag);

057

 iferr_X | err_k | err_ltol | err_maxiter | err_pflag,return;end

058

elseifnargin == 6,

059

 err_X = Verify_X(X);

060

 err_k = Verify_k(k);

061

 [ltol,err_ltol] = Verify_ltol(ltol);

062

 [maxiter,err_maxiter] = Verify_maxiter(maxiter);

063

 [pflag,err_pflag] = Verify_pflag(pflag);

064

 [Init,err_Init]=Verify_Init(Init);

065

 iferr_X | err_k | err_ltol | err_maxiter | err_pflag | err_Init,return;end

066

else

067

 disp('EM_GM must have 2 to 6 inputs!');

068

 return

069

end

070

  

071

%%%% Initialize W, M, V,L %%%%

072

t = cputime;

073

ifisempty(Init),

074

 [W,M,V] = Init_EM(X,k); L = 0;

075

else

076

 W = Init.W;

077

 M = Init.M;

078

 V = Init.V;

079

end

080

Ln = Likelihood(X,k,W,M,V);% Initialize log likelihood

081

Lo = 2*Ln;

082

  

083

%%%% EM algorithm %%%%

084

niter = 0;

085

while(abs(100*(Ln-Lo)/Lo)>ltol) & (niter<=maxiter),

086

 E = Expectation(X,k,W,M,V);% E-step

087

 [W,M,V] = Maximization(X,k,E); % M-step

088

 Lo = Ln;

089

 Ln = Likelihood(X,k,W,M,V);

090

 niter = niter + 1;

091

end

092

L = Ln;

093

  

094

%%%% Plot 1D or 2D %%%%

095

ifpflag==1,

096

 [n,d] = size(X);

097

 ifd>2,

098

 disp('Can only plot 1 or 2 dimensional applications!/n');

099

 else

100

 Plot_GM(X,k,W,M,V);

101

 end

102

 elapsed_time = sprintf('CPU time used for EM_GM: %5.2fs',cputime-t);

103

 disp(elapsed_time);

104

 disp(sprintf('Number of iterations: %d',niter-1));

105

end

106

%%%%%%%%%%%%%%%%%%%%%%

107

%%%% End of EM_GM %%%%

108

%%%%%%%%%%%%%%%%%%%%%%

109

  

110

functionE = Expectation(X,k,W,M,V)

111

[n,d] = size(X);

112

a = (2*pi)^(0.5*d);

113

S = zeros(1,k);

114

iV = zeros(d,d,k);

115

forj=1:k,

116

 ifV(:,:,j)==zeros(d,d), V(:,:,j)=ones(d,d)*eps;end

117

 S(j) = sqrt(det(V(:,:,j)));

118

 iV(:,:,j) = inv(V(:,:,j));

119

end

120

E = zeros(n,k);

121

fori=1:n,

122

 forj=1:k,

123

 dXM = X(i,:)'-M(:,j);

124

 pl = exp(-0.5*dXM'*iV(:,:,j)*dXM)/(a*S(j));

125

 E(i,j) = W(j)*pl;

126

 end

127

 E(i,:) = E(i,:)/sum(E(i,:));

128

end

129

%%%%%%%%%%%%%%%%%%%%%%%%%%%%

130

%%%% End of Expectation %%%%

131

%%%%%%%%%%%%%%%%%%%%%%%%%%%%

132

  

133

function[W,M,V] = Maximization(X,k,E)

134

[n,d] = size(X);

135

W = zeros(1,k); M = zeros(d,k);

136

V = zeros(d,d,k);

137

fori=1:k, % Compute weights

138

 forj=1:n,

139

 W(i) = W(i) + E(j,i);

140

 M(:,i) = M(:,i) + E(j,i)*X(j,:)';

141

 end

142

 M(:,i) = M(:,i)/W(i);

143

end

144

fori=1:k,

145

 forj=1:n,

146

 dXM = X(j,:)'-M(:,i);

147

 V(:,:,i) = V(:,:,i) + E(j,i)*dXM*dXM';

148

 end

149

 V(:,:,i) = V(:,:,i)/W(i);

150

end

151

W = W/n;

152

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

153

%%%% End of Maximization %%%%

154

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

155

  

156

functionL = Likelihood(X,k,W,M,V)

157

% Compute L based on K. V. Mardia, "Multivariate Analysis", Academic Press, 1979, PP. 96-97

158

% to enchance computational speed

159

[n,d] = size(X);

160

U = mean(X)';

161

S = cov(X);

162

L = 0;

163

fori=1:k,

164

 iV = inv(V(:,:,i));

165

 L = L + W(i)*(-0.5*n*log(det(2*pi*V(:,:,i))) ...

166

 -0.5*(n-1)*(trace(iV*S)+(U-M(:,i))'*iV*(U-M(:,i))));

167

end

168

%%%%%%%%%%%%%%%%%%%%%%%%%%%

169

%%%% End of Likelihood %%%%

170

%%%%%%%%%%%%%%%%%%%%%%%%%%%

171

  

172

functionerr_X = Verify_X(X)

173

err_X = 1;

174

[n,d] = size(X);

175

ifn<d,

176

 disp('Input data must be n x d!/n');

177

 return

178

end

179

err_X = 0;

180

%%%%%%%%%%%%%%%%%%%%%%%%%

181

%%%% End of Verify_X %%%%

182

%%%%%%%%%%%%%%%%%%%%%%%%%

183

  

184

functionerr_k = Verify_k(k)

185

err_k = 1;

186

if~isnumeric(k) | ~isreal(k) | k<1,

187

 disp('k must be a real integer >= 1!/n');

188

 return

189

end

190

err_k = 0;

191

%%%%%%%%%%%%%%%%%%%%%%%%%

192

%%%% End of Verify_k %%%%

193

%%%%%%%%%%%%%%%%%%%%%%%%%

194

  

195

function[ltol,err_ltol] = Verify_ltol(ltol)

196

err_ltol = 1;

197

ifisempty(ltol),

198

 ltol = 0.1;

199

elseif~isreal(ltol) | ltol<=0,

200

 disp('ltol must be a positive real number!');

201

 return

202

end

203

err_ltol = 0;

204

%%%%%%%%%%%%%%%%%%%%%%%%%%%%

205

%%%% End of Verify_ltol %%%%

206

%%%%%%%%%%%%%%%%%%%%%%%%%%%%

207

  

208

function[maxiter,err_maxiter] = Verify_maxiter(maxiter)

209

err_maxiter = 1;

210

ifisempty(maxiter),

211

 maxiter = 1000;

212

elseif~isreal(maxiter) | maxiter<=0,

213

 disp('ltol must be a positive real number!');

214

 return

215

end

216

err_maxiter = 0;

217

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

218

%%%% End of Verify_maxiter %%%%

219

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

220

  

221

function[pflag,err_pflag] = Verify_pflag(pflag)

222

err_pflag = 1;

223

ifisempty(pflag),

224

 pflag = 0;

225

elseifpflag~=0 & pflag~=1,

226

 disp('Plot flag must be either 0 or 1!/n');

227

 return

228

end

229

err_pflag = 0;

230

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

231

%%%% End of Verify_pflag %%%%

232

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

233

  

234

function[Init,err_Init] = Verify_Init(Init)

235

err_Init = 1;

236

ifisempty(Init),

237

 % Do nothing;

238

elseifisstruct(Init),

239

 [Wd,Wk] = size(Init.W);

240

 [Md,Mk] = size(Init.M);

241

 [Vd1,Vd2,Vk] = size(Init.V);

242

 ifWk~=Mk | Wk~=Vk | Mk~=Vk,

243

 disp('k in Init.W(1,k), Init.M(d,k) and Init.V(d,d,k) must equal!/n')

244

 return

245

 end

246

 ifMd~=Vd1 | Md~=Vd2 | Vd1~=Vd2,

247

 disp('d in Init.W(1,k), Init.M(d,k) and Init.V(d,d,k) must equal!/n')

248

 return

249

 end

250

else

251

 disp('Init must be a structure: W(1,k), M(d,k), V(d,d,k) or []!');

252

 return

253

end

254

err_Init = 0;

255

%%%%%%%%%%%%%%%%%%%%%%%%%%%%

256

%%%% End of Verify_Init %%%%

257

%%%%%%%%%%%%%%%%%%%%%%%%%%%%

258

  

259

function[W,M,V] = Init_EM(X,k)

260

[n,d] = size(X);

261

[Ci,C] = kmeans(X,k,'Start','cluster', ...

262

 'Maxiter',100, ...

263

 'EmptyAction','drop', ...

264

 'Display','off');% Ci(nx1) - cluster indeices; C(k,d) - cluster centroid (i.e. mean)

265

whilesum(isnan(C))>0,

266

 [Ci,C] = kmeans(X,k,'Start','cluster', ...

267

 'Maxiter',100, ...

268

 'EmptyAction','drop', ...

269

 'Display','off');

270

end

271

M = C';

272

Vp = repmat(struct('count',0,'X',zeros(n,d)),1,k);

273

fori=1:n,% Separate cluster points

274

 Vp(Ci(i)).count = Vp(Ci(i)).count + 1;

275

 Vp(Ci(i)).X(Vp(Ci(i)).count,:) = X(i,:);

276

end

277

V = zeros(d,d,k);

278

fori=1:k,

279

 W(i) = Vp(i).count/n;

280

 V(:,:,i) = cov(Vp(i).X(1:Vp(i).count,:));

281

end

282

%%%%%%%%%%%%%%%%%%%%%%%%

283

%%%% End of Init_EM %%%%

284

%%%%%%%%%%%%%%%%%%%%%%%%

285

  

286

functionPlot_GM(X,k,W,M,V)

287

[n,d] = size(X);

288

ifd>2,

289

 disp('Can only plot 1 or 2 dimensional applications!/n');

290

 return

291

end

292

S = zeros(d,k);

293

R1 = zeros(d,k);

294

R2 = zeros(d,k);

295

fori=1:k, % Determine plot range as 4 x standard deviations

296

 S(:,i) = sqrt(diag(V(:,:,i)));

297

 R1(:,i) = M(:,i)-4*S(:,i);

298

 R2(:,i) = M(:,i)+4*S(:,i);

299

end

300

Rmin = min(min(R1));

301

Rmax = max(max(R2));

302

R = [Rmin:0.001*(Rmax-Rmin):Rmax];

303

clf, hold on

304

ifd==1,

305

 Q = zeros(size(R));

306

 fori=1:k,

307

 P = W(i)*normpdf(R,M(:,i),sqrt(V(:,:,i)));

308

 Q = Q + P;

309

 plot(R,P,'r-'); grid on,

310

 end

311

 plot(R,Q,'k-');

312

 xlabel('X');

313

 ylabel('Probability density');

314

else% d==2

315

 plot(X(:,1),X(:,2),'r.');

316

 fori=1:k,

317

 Plot_Std_Ellipse(M(:,i),V(:,:,i));

318

 end

319

 xlabel('1^{st} dimension');

320

 ylabel('2^{nd} dimension');

321

 axis([Rmin Rmax Rmin Rmax])

322

end

323

title('Gaussian Mixture estimated by EM');

324

%%%%%%%%%%%%%%%%%%%%%%%%

325

%%%% End of Plot_GM %%%%

326

%%%%%%%%%%%%%%%%%%%%%%%%

327

  

328

functionPlot_Std_Ellipse(M,V)

329

[Ev,D] = eig(V);

330

d = length(M);

331

ifV(:,:)==zeros(d,d),

332

 V(:,:) = ones(d,d)*eps;

333

end

334

iV = inv(V);

335

% Find the larger projection

336

P = [1,0;0,0]; % X-axis projection operator

337

P1 = P * 2*sqrt(D(1,1)) * Ev(:,1);

338

P2 = P * 2*sqrt(D(2,2)) * Ev(:,2);

339

ifabs(P1(1)) >= abs(P2(1)),

340

 Plen = P1(1);

341

else

342

 Plen = P2(1);

343

end

344

count = 1;

345

step = 0.001*Plen;

346

Contour1 = zeros(2001,2);

347

Contour2 = zeros(2001,2);

348

forx = -Plen:step:Plen,

349

 a = iV(2,2);

350

 b = x * (iV(1,2)+iV(2,1));

351

 c = (x^2) * iV(1,1) - 1;

352

 Root1 = (-b + sqrt(b^2 - 4*a*c))/(2*a);

353

 Root2 = (-b - sqrt(b^2 - 4*a*c))/(2*a);

354

 ifisreal(Root1),

355

 Contour1(count,:) = [x,Root1] + M';

356

 Contour2(count,:) = [x,Root2] + M';

357

 count = count + 1;

358

 end

359

end

360

Contour1 = Contour1(1:count-1,:);

361

Contour2 = [Contour1(1,:);Contour2(1:count-1,:);Contour1(count-1,:)];

362

plot(M(1),M(2),'k+');

363

plot(Contour1(:,1),Contour1(:,2),'k-');

364

plot(Contour2(:,1),Contour2(:,2),'k-');

365

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

366

%%%% End of Plot_Std_Ellipse %%%%

367

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值