1、算法基本原理:
- EM算法一般用于存在隐变量或潜在变量的概率模型,可以算是一种含有隐的概率模型参数的极大似然估计法;
- 假设 θ 为模型的参数, 为模型的观测数据, γ 模型中存在的隐藏变量,EM算法的是通过最大化观测数据 logP(Y|θ) 的方法来求出 θ 的极大似然估计,可以转化为表达式: θ^=argmaxθ(logP(Y|θ))
- 经过转化,可以将问题转化为最大化 E(γ) 的问题,即 θ^=argmaxγ(E(γ)) 。
2、算法推导过程:
- 根据极大似然法的原理,我们的目标是极大化观测数据
Y
关于参数
θ 的对数似然函数,即:L(θ)=logP(Y|θ)=log∑γP(Y,γ|θ)=log(∑λP(Y|γ,θ)P(Z|θ)) - 因为EM算法是通过迭代的办法逐步接近极大
L(θ)
的,假设在第
i
次迭代后
θi ,此我们希望能够使 L(θ)−L(θ(i))≥0L(θ)−L(θi)=log(∑γP(Y|γ,θ)P(γ|θ))−log(P(Y|θi)=log(∑γP(γ|Y,θi)P(Y|γ,θ)P(γ|θ)P(γ|Y,θi))−logP(Y|θi)≥∑γP(γ|Y,θi)log(P(Y|γ,θ)P(γ|θ)P(γ|Y,θi))−logP(Y|θi)=∑γP(γ|Y,θi)log(P(Y|γ,θ)P(γ|θ)P(γ|Y,θi)logP(Y|θi))令B(θ,θi)=L(θi)+∑γP(γ|Y,θi)log(P(Y|γ,θ)P(γ|θ)P(γ|Y,θi)logP(Y|θi))则L(θ)≥B(θ,θi)即函数 B(θ,θi) 是 L(θ) 的一个下界,并且易知: L(θi)≥B(θi,θi) ,因此可以使 B(θ,θi) 增大的 θ 也可以使 L(θ) 增大,为了使 L(θ) 有尽可能大的增大,选择 θi+1 使 B(θ,θi) 打到极大,即:θ(i+1)=argmaxθB(θ,θi)上式可以改写为:θ(i+1)=argmaxθ(L(θi)+∑γP(γ|Y,θi)log(P(Y|γ,θ)P(γ|θ)P(γ|Y,θi)logP(Y|θi)))=argmaxθ∑γP(γ|Y),θilog(P(Y|γ,θ)P(γ|θ))=argmaxθ∑γP(γ|Y,θi)log(P(Y,γ|θ))=argmaxθQ(θ,θi)3、EM算法收敛性证明: 根据对数函数函数性质:若 P(Y|θi) 单调递增且收敛到某一值则 Q(θ,θi) 收敛。 单调性:P(Y|θ)=P(Y,γ|θ)P(γ|Y,θ)logP(Y|θ)=logP(Y,γ|θ)−logP(γ|Y,θ)Q(θ,θi)=∑γlogP(Y,γ|θ)P(γ|Y,θi)令H(θ,θi)=∑γlogP(γ|Y,θ)P(γ|Y,θi)于是对数似然函数可以写成:logP(Y|θ)=Q(θ,θi)−H(θ,θi)上式中 θ 分别取为 θi 和 θi+1 并相减,有:logP(Y|θi+1)−logP(Y|θi)=[Q(θi+1,θi)−Q(θi,θi)]−[H(θi+1,θi)−H(θi,θi)]因为 θi+1 使Q(\theta,\theta^{i})达到极大,所以有:Q(θi+1,θi)−Q(θi,θi)≥0其第2项,可以推导得出:H(θi+1,θi)−H(θi,θi)=∑γ(logp(γ|Y,θi+1)P(γ|Y,θi))P(γ|Y,θi)≤log(∑γP(γ|Y,θi+1)P(γ|Y,θi)P(γ|Y,θi))=log(P(γ|Y,θi+1))=0又因为 P(Y|θi) 有界,所以 L(θi)=log(P(Y|θi)) 收敛到某一值 L∗ 。
4、算法步骤:
- 选择参数的初值 θ0 ,开始迭代;
- E步:记
θi
为第
i
次迭代参数
θ 的估计值,在第 i 次迭代的E步,计算:Q(θ,θi)=Eγ[logP(Y,γ|θ)|Y,θ] =∑γlog(P(Y,γ|θ)P(γ|Y,θi)) - M步:求使
Q(θ,θi)
极大化的
θ
,确定第
i+1
次迭代的参数的估计值
θi+1
θi+1=argmaxθQ(θ,θi)
-重复第E步和第M步,直到对于较小的正数 ξ1 , ξ2 ,若满足 :||θi+1−θi||≤ξq或||Q(θi+1,θi)−Q(θi,θi)||≤ξ2则停止迭代。
package binorandom;
public class binomain {
public static void main(String[] args) {
int[] b=new int[1000];
for (int i=0;i<1000;i++){
b[i]=binorandom.getBinomial(1, 0.4);
}
int[] a=new int[1000];
for ( int i=0;i<999;i++){
if (b[i]==1){
a[i]=binorandom.getBinomial(1,0.5);
}
if(b[i]==0){
a[i]=binorandom.getBinomial(1,0.6);
}
System.out.print(a[i]+" ");
}
System.out.print(a[999]);
}
}
package binorandom;
public class binorandom {
public static int getBinomial(int n, double p) {
int x = 0;
for(int i = 0; i < n; i++) {
if(Math.random() < p)
x++;
}
return x;
}
}
//生成数据集合
package EMpackage;
import java.util.Scanner;
public class EMmain {
public static void main(String[] args){
System.out.println("请输入观测值个数");
Scanner input=new Scanner(System.in);
int datanumber=input.nextInt();
System.out.println("请输入观测值(0或者1):");
Scanner input1=new Scanner(System.in);
int[] obdata=new int[datanumber];
for(int i=0; i<datanumber;i++){
obdata[i]=input1.nextInt();
}
System.out.println("您输入的是:"+" ");
for (int b=0;b<datanumber-1;b++){
System.out.print(obdata[b]+" ");
}
System.out.println(obdata[datanumber-1]);
double[] original=new double[3];
original=ori.original();
double eq=ori.eq();
System.out.println("初始条件为:"+" "+original[0]+" "+original[1]+" "+original[2]);
System.out.println("停止条件为:"+" "+eq);
input1.close();
input.close();
double[] original1=new double[3];
original1=EM.original1(original, obdata, datanumber);
int x=0;
while (euclid(minus(original1,original))>eq){
original=original1;
original1=EM.original1(original,obdata,datanumber);
x=x+1;
}
System.out.println("pi="+original1[0]+"\n"+"p="+original1[1]+"\n"+"q="+original1[2]+"\n"+x);
}
private static double euclid(double[] x) {
double sum=0;
for (int i=0;i<3;i++){
sum=sum+Math.pow(x[i], 2);
}
double euclid=Math.sqrt(sum);
return euclid;
}
private static double[] minus(double[] x,double[] y) {
double[] temp=new double[3];
for (int i=0;i<3;i++){
temp[i]=x[i]-y[i];
}
return temp;
}
}
package EMpackage;
public class EM {
public static double[] original1(double[] original,int[] obdata,int datanumber){
double[] ybl=new double[datanumber];
double[] uybl=new double[datanumber];
double[] l=new double[datanumber];
double datanumber1=datanumber;
for (int i=0;i<datanumber;i++){
ybl[i]=(original[0]*Math.pow(original[1],obdata[i] )*Math.pow(1-original[1],1-obdata[i] ))/(original[0]*Math.pow(original[1],obdata[i])*Math.pow((1-original[1]),(1-obdata[i]))+(1-original[0])*Math.pow(original[2],obdata[i])*Math.pow((1-original[2]),(1-obdata[i])));
uybl[i]=obdata[i]*(original[0]*Math.pow(original[1],obdata[i] )*Math.pow(1-original[1],1-obdata[i] ))/(original[0]*Math.pow(original[1],obdata[i])*Math.pow((1-original[1]),(1-obdata[i]))+(1-original[0])*Math.pow(original[2],obdata[i])*Math.pow((1-original[2]),(1-obdata[i])));
l[i]=1;
}
double[] original1=new double[3];
original1[0]=(1/datanumber1)*sum(ybl,datanumber);
original1[1]=(sum(uybl,datanumber)/sum(ybl,datanumber));
original1[2]=(sum(ybl,datanumber)-sum(uybl,datanumber))/(sum(l,datanumber)-sum(ybl,datanumber));
return original1;
}
private static double sum(double[] ybl,int datanumber) {
double sum=0;
for (int i=0;i<datanumber;i++){
sum=sum+ybl[i];
}
return sum;
}
}
package EMpackage;
import java.util.Scanner;
public class ori{
public static double[] original(){
System.out.println("请输入初始条件条件:"+" ");
Scanner input=new Scanner(System.in);
double original[]=new double[3];
for(int d=0; d<3;d++){
original[d]=input.nextDouble();
}
return original;
}
public static double eq(){
System.out.println("请输入停止条件:"+" ");
Scanner input=new Scanner(System.in);
double eq=input.nextDouble();
return eq;
}
}
//EM算法主程序
实验结果及实例分析
多次运算结果对比:
原始系数pi,p,q(0.4、0.5、0.6):
初始迭代系数 | (0.5、0.5、0.5) | (0.4、0.4、0.4) | (0.4、0.4、0.5) | (0.5、0.4、0.6) | (0.4、0.5、0.4) | (0.5、0.4、0.5) | (0.5、0.6、0.4) |
---|---|---|---|---|---|---|---|
运算结果 | (0.5、0.73、0.32) | (0.54、0.84、0.19) | (0.55、0.84、0.19) | (0.56、0.77、0.29) | (0.56、0.85、0.19) | (0.56、0.77、0.30) | (0.56、0.76、0.30) |
原始系数pi,p,q(0.5、0.5、0.5):
初始迭代系数 | (0.4、0.4、0.4) | (0.3、0.4、0.4) | (0.4、0.4、0.5) | (0.4、0.4、0.6) | (0.4、0.5、0.4) | (0.5、0.4、0.3) | (0.5、0.6、0.4) |
---|---|---|---|---|---|---|---|
运算结果 | (0.49、0.7、0.23) | (0.49、0.85、0.14) | (0.5、0.76、0.23) | (0.49、0.75、0.24) | (0.49、0.76、0.23) | (0.49、1.02、-0.02) | (0.5、0.53、0.45) |
从以上两表不难看出EM算法受到初始迭代值的影响十分大,但是其优点在于需要的迭代次数少,收敛速度十分迅速。