// 本人编写的一维 EM 算法
// 初学 EM 算法时被各种公式吓到，学的过程也不是很顺利（本人数学渣渣），好不容易弄懂了大概，现在编写个程序来检验一下。
import scala.math
/** One-dimensional EM (expectation-maximisation) for a mixture of Gaussians.
  *
  * All model state lives in mutable fields. The E-step fills `p` with
  * posterior responsibilities. The M-step re-estimates `means`, `variances`
  * and `probs` from `p`, updating those arrays in place. The number of
  * components is `means.length`; `variances` and `probs` must have the same
  * length.
  */
object Main {
  /** Observed points (1-D here; the algorithm could be extended to n dimensions). */
  val data = Array(1.0, 1.3, 2.2, 2.6, 2.8, 5.0, 7.3, 7.4, 7.5, 7.7, 7.9)
  /** Current mean of each component. */
  var means = Array(4.0, 8.0)
  /** Current variance of each component. */
  var variances = Array(1.0, 1.0)
  /** Mixing weight (prior probability of choosing each component). */
  var probs = Array(0.5, 0.5)
  /** p(i)(j): posterior probability that point j belongs to component i. */
  var p = Array.ofDim[Double](means.length, data.length)

  /** Density of the normal distribution N(mean, variance) at `point`.
    *
    * @param point value at which to evaluate the density
    * @param means mean of the Gaussian
    * @param probs VARIANCE of the Gaussian (the parameter name is historical;
    *              it is not a probability)
    * @return the probability density at `point`
    */
  def gaussianPro(point: Double, means: Double, probs: Double): Double = {
    val d = point - means
    math.exp(-d * d / (2 * probs)) / math.sqrt(2 * math.Pi * probs)
  }

  /** E-step: recompute `p`, the posterior of every component for every point. */
  def Computepostpro(): Unit = {
    val k = means.length
    // Joint weight probs(i) * N(data(j); means(i), variances(i)), computed once
    // and reused for both the numerator and the denominator.
    val weighted = Array.tabulate(k, data.length) { (i, j) =>
      probs(i) * gaussianPro(data(j), means(i), variances(i))
    }
    // Total likelihood of each point across all components. If every component
    // gives a point ~0 density, this underflows to 0 and that point's posteriors become NaN.
    val denominator = Array.tabulate(data.length)(j => (0 until k).map(i => weighted(i)(j)).sum)
    p = Array.tabulate(k, data.length)((i, j) => weighted(i)(j) / denominator(j))
  }

  /** M-step part 1: each mean becomes the posterior-weighted average of the data. */
  def GuessMean(): Unit =
    for (i <- means.indices) {
      val totalWeight = p(i).sum
      val weightedSum = p(i).zip(data).map { case (w, x) => w * x }.sum
      means(i) = weightedSum / totalWeight
    }

  /** M-step part 2: posterior-weighted variance around the (already updated) means.
    * Must run after [[GuessMean]] so that it uses the new means.
    */
  def Guessvariances(): Unit =
    for (i <- variances.indices) {
      val totalWeight = p(i).sum
      val weightedSq = p(i).zip(data).map { case (w, x) => w * (x - means(i)) * (x - means(i)) }.sum
      variances(i) = weightedSq / totalWeight
    }

  /** M-step part 3: each mixing weight is the average posterior of its component. */
  def ComputeProbs(): Unit =
    for (i <- probs.indices)
      probs(i) = p(i).sum / p(i).length

  /** Expectation step. */
  def E_Step(): Unit = Computepostpro()

  /** Maximisation step. The order matters: variances use the freshly updated means. */
  def M_Step(): Unit = {
    GuessMean()
    Guessvariances()
    ComputeProbs()
  }

  /** Returns true when the parameters have converged.
    *
    * Convergence means the total absolute shift of the means, and separately
    * of the variances, between the snapshot and the current values is below `tol`.
    * Absolute values are used so that shifts in opposite directions cannot cancel out.
    *
    * @param localmean      snapshot of `means` taken before the last EM step
    * @param localvariances snapshot of `variances` taken before the last EM step
    * @param tol            convergence tolerance (defaults to 0.01)
    */
  def Threshold(localmean: Array[Double], localvariances: Array[Double], tol: Double = 0.01): Boolean = {
    require(localmean.length == means.length && localvariances.length == variances.length,
      "snapshot sizes must match the number of components")
    val meanShift = means.indices.map(i => math.abs(localmean(i) - means(i))).sum
    val varShift = variances.indices.map(i => math.abs(localvariances(i) - variances(i))).sum
    meanShift < tol && varShift < tol
  }

  /** Prints mean, variance and mixing weight of every component. */
  def Output(): Unit =
    for (i <- means.indices)
      println(s"第${i}个类的均值为:${means(i)} 方差为:${variances(i)} 选择这个类的概率为:${probs(i)}")

  /** Runs EM until convergence or for at most 20 iterations, then prints the result. */
  def main(args: Array[String]): Unit = {
    var remaining = 20 // maximum number of EM iterations
    var converged = false
    while (remaining > 0 && !converged) {
      remaining -= 1
      // Snapshot copies: the M-step mutates the arrays in place, so a plain
      // reference would alias them and the comparison would always be trivially equal.
      val localmean = means.clone()
      val localvariances = variances.clone()
      E_Step()
      M_Step()
      converged = Threshold(localmean, localvariances)
    }
    Output()
  }
}