EM算法 高斯模型 的参数估计 C++实现

因为去百度实习miss模式识别课程的一些课时,自己看了些资料补上,顺便实现以下,让自己有个更深的印象。


高斯分布:


GMM 混合高斯模型  简易通俗理解:

假如大学生的的男生和女生的身高分别符合高斯分布G1和G2,G1和G2的参数(均值u,方差sigma)都不知道。

现在测得了某高中所有学生的身高,可惜的是测量人只记录了身高值,没有记录男女性别,也无从知道男女生的人数比例pi1,pi2。

以上提到的参数G1,G2,pi1,pi2就构成了一个混合高斯模型。

1. 我们想得到G1和G2的参数,即男女生的平均身高和变化幅度,还有男女生人数的占比,怎么算?

通过最大似然估计进行参数估计,最大似然直接求导困难,改用迭代求解,迭代求解过程称为 EM算法,该方法的缺点是只能求局部最优。

【最大似然估计就是让概率分布参数最符合观察到的样本】

2.GMM作用是什么?

给出某个高中生的身高,估计其为男生和女生的概率各是多少,做性别估计或分类。

参考:

http://blog.csdn.net/zouxy09/article/details/8537620

http://amberlife.net/2012/07/机器学习-gmm心得体会/


EM Expectation Maximization过程:

EM算法包括两个过程:

E 已知参数 lamda,求概率(期望)P;

M 根据算出的P,求新的lamda以使出现P的概率最大。



代码实现

代码下载:

http://download.csdn.net/detail/hzq20081121107/6966067

程序的使用:

1

在对话框中随机点一些点,这些点为观察样本,其的横坐标为观察样本的值。

2

在编辑框中输入想聚类的数量k,然后点击“模型初始化”按钮,程序将产生k个高斯分布,k个高斯分布的均值u设为前三个样本的观察值,方差sigma设为1.

3

点击当前高斯曲线,绘制k条当前参数下的高斯曲线。

4

点击“EM一次”按钮,进行一次EM迭代,会绘制一次当前参数下的k条高斯曲线。

多次点击,显示迭代效果。


推理参考:

http://amberlife.net/2012/07/机器学习-gmm心得体会/

http://blog.csdn.net/abcjennifer/article/details/8198352

http://wenku.baidu.com/link?url=hkSwNnMmils9NI7LWcjHHgahXkyskK1xg9iZzFqLBPZwgtRKECCrbVLHnDrB0WVPbIJa0ma2UBQv7T8Twp_HSuBfZ4cGtLHz9lfUVmQ4jzm

http://blog.csdn.net/zouxy09/article/details/8537620


加深学习参看: GMM高斯混合模型学习(2)

实现过程:

设计一个ClassGMM类:

ClassGMM.h

#pragma once
struct struct_sampleNode
{
    float x;
};
struct struct_GM
{
	float Nk;
	float piK;
	float u;
	float sigma;
};
class ClassGMM
{
public:
	struct_GM* pGM;
	struct_sampleNode* pSampleNode;
	int numSample;
	int numClass;
	ClassGMM(void);
	~ClassGMM(void);
	void clear(void);
	int EM(int nTimes);
	float getGaussRatio(float u , float sigma , float x);
};
ClassGMM.cpp

#include "StdAfx.h"
#include <cmath>
#include "ClassGMM.h"

#define PI acos(-1.0)


ClassGMM::ClassGMM(void)
{
	numSample=0;
	numClass=0;
	pGM=NULL;
	pSampleNode=NULL;
}


ClassGMM::~ClassGMM(void)
{
	if(pGM!=NULL)
	{
	    delete []pGM;
		pGM=NULL;
	}
	if(pSampleNode!=NULL)
	{
	    delete []pSampleNode;
		pSampleNode=NULL;
	}
}


void ClassGMM::clear(void)
{
	numSample=0;
	numClass=0;
	if(pGM!=NULL)
	{
		delete []pGM;
		pGM=NULL;
	}
	if(pSampleNode!=NULL)
	{
		delete []pSampleNode;
		pSampleNode=NULL;
	}
}

//EM迭代 nTimes
int ClassGMM::EM(int nTimes)
{
	int i,j,k;
	//开辟空间
	float **ppR;
	ppR=new float* [numSample];
	for(i=0;i<numSample;i++)
	{
	    ppR[i]=new float[numClass];
	}


	while(--nTimes>=0)
	{
		//计算r(i,k),用ppR[i,k]表示
		for(i=0;i<numSample;i++)
		{
			float *piKN;
			piKN=new float[numClass];//
			float sum_piKN=0;
			for(k=0;k<numClass;k++)
			{
				piKN[k]=pGM[k].piK*getGaussRatio(pGM[k].u,pGM[k].sigma,pSampleNode[i].x);
				sum_piKN+=piKN[k];
			}
			for(k=0;k<numClass;k++)
			{
				ppR[i][k]=piKN[k]/sum_piKN;
			}
			delete []piKN;
		}
		//利用r(i,k)求高斯参数,Nk,piK,u,sigma,
		//Nk
		for(k=0;k<numClass;k++)
		{
			pGM[k].Nk=0;
			for(i=0;i<numSample;i++)
			{
				pGM[k].Nk += ppR[i][k];
			}
		}
		//piK
		float sumNK=0;
		for(k=0;k<numClass;k++)
		{
			sumNK+=pGM[k].Nk;
		}
		for(k=0;k<numClass;k++)
		{
			pGM[k].piK=pGM[k].Nk/sumNK;
		}
		//u
		for(k=0;k<numClass;k++)
		{
			pGM[k].u=0;
			for(i=0;i<numSample;i++)
			{
				pGM[k].u+=pSampleNode[i].x*ppR[i][k];
			}
			pGM[k].u/=pGM[k].Nk;
		}
		//sigma
		for(k=0;k<numClass;k++)
		{
			float sum_temp=0;
			for(i=0;i<numSample;i++)
			{
				sum_temp+=ppR[i][k]*(pSampleNode[i].x-pGM[k].u)*(pSampleNode[i].x-pGM[k].u);
			}
			sum_temp/=pGM[k].Nk;
			pGM[k].sigma=sqrt(sum_temp);
		}
	}
	//销毁空间
	for(i=0;i<numSample;i++)
	{
		delete []ppR[i];
	}
	delete []ppR;

	return 0;
}

//求高斯概率密度
float ClassGMM::getGaussRatio(float u , float sigma , float x)
{
	double fenMu=sqrt(2*PI)*sigma;
	double zhiShu=-(x-u)*(x-u)/(2*sigma*sigma);
	double fenZi=exp(zhiShu);
	return fenZi/fenMu;
}

对话框中的一些主要函数:

//GMMO的初始化
void CGMM_testDlg::OnBnClickedButton3()
{
	CRect clientRc;
	GetClientRect(&clientRc);
	int i,j,k;
	//GMMO的初始化
	GMMO.clear();
	//样本赋值
	GMMO.numSample=numPos;
	GMMO.pSampleNode=new struct_sampleNode[numPos];
	for(i=0;i<GMMO.numSample;i++)
	{
	    GMMO.pSampleNode[i].x=(pos_g[i].x)/100.0;
	}
	//高斯分布初始化
	GMMO.numClass=2;
	UpdateData(TRUE);
	if(m_numClass!=0)
	{GMMO.numClass=m_numClass;}
	GMMO.pGM=new struct_GM[GMMO.numClass];
	for(k=0;k<GMMO.numClass;k++)
	{
		GMMO.pGM[k].u=GMMO.pSampleNode[k].x;
		GMMO.pGM[k].sigma=1;
		GMMO.pGM[k].piK=0.5;
		GMMO.pGM[k].Nk=GMMO.pGM[k].piK;
	}
}
//EM过程
void CGMM_testDlg::OnBnClickedButton2()
{
	//GMM模型迭代 10000次
	GMMO.EM(1);
	drawGMM();
}
//显示正态概率密度图
void CGMM_testDlg::OnBnClickedButton4()
{
	CDC *pDC=GetDC();
	CRect rec_window;
	CPen  myPen,*oldPen;
	myPen.CreatePen(1,2,RGB(255,0,0));
	oldPen=pDC->SelectObject(&myPen);
	GetClientRect(&rec_window);
	int x;
	int i,j,k;
	float u,sigma;
	for(k=0;k<GMMO.numClass;k++)
	{
		u=GMMO.pGM[k].u;
		sigma=GMMO.pGM[k].sigma;
		for(x=rec_window.left ;x<rec_window.right;x++)
		{
			int t=100;
			int temp1=rec_window.Height();
			float temp2=t*getGaussRou(x/100.0,u,sigma);
			float temp3=temp1-temp2;
			if(int(u*100)==x)
				x=x;
			pDC->MoveTo(x,rec_window.Height()- t*getGaussRou(x/100.0,u,sigma));
			pDC->LineTo(x+1,rec_window.Height()-t*getGaussRou((x+1)/100.0,u,sigma));
		}
	}
	pDC->SelectObject(oldPen);
	myPen.DeleteObject();
	ReleaseDC(pDC);
}
//绘制各高斯分布的曲线
int CGMM_testDlg::drawGMM(void)
{
	int i,j,k;
	for(k=0;k<GMMO.numClass;k++)
	{drawGauss(GMMO.pGM[k].u,GMMO.pGM[k].sigma);}
	return 0;
}
//绘制期望为u,方差为sigma的正态分布曲线
void CGMM_testDlg::drawGauss(float u, float sigma)
{
	CDC *pDC=GetDC();
	CRect rec_window;
	GetClientRect(&rec_window);
	int x;
	for(x=rec_window.left ;x<rec_window.right;x++)
	{
		int t=100;
		int temp1=rec_window.Height();
		float temp2=t*getGaussRou(x/100.0,u,sigma);
		float temp3=temp1-temp2;
		if(int(u*100)==x)
			x=x;
		pDC->MoveTo(x,rec_window.Height()- t*getGaussRou(x/100.0,u,sigma));
		pDC->LineTo(x+1,rec_window.Height()-t*getGaussRou((x+1)/100.0,u,sigma));
	}
}



//左键标点
void CGMM_testDlg::OnLButtonDown(UINT nFlags, CPoint point)
{
	// TODO: 在此添加消息处理程序代码和/或调用默认值
	int r=2;
	CDC *pDC=GetDC();
	CPoint posNear;
	int numTemp=1+rand()%13;
	for(int i=0;i<numTemp;i++)
	{
		int rr=100;
		posNear.x=point.x+rand()%rr-rr/2;
		posNear.y=point.y+rand()%rr-rr/2;
		pDC->Ellipse(posNear.x-r,posNear.y-r,posNear.x+r,posNear.y+r);
		pos_g[numPos++]=posNear;
	}
	CDialogEx::OnLButtonDown(nFlags, point);
}



评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值