scala下串行实现em算法

本人编写的一维em算法

初学em算法时被各种公式吓到,学的过程也不是很顺利(本人数学渣渣),好不容易弄懂了大概,现在来编写个程序来检验下

import scala.math
object Main {
  val data=Array(1.0,1.3,2.2,2.6,2.8,5.0,7.3,7.4,7.5,7.7,7.9)//点的数据,这里数据只要一维,当然可以为n维
  var means=Array(4.0,8.0)//均值
  var variances=Array(1.0,1.0)//方差
  var probs=Array(0.5,0.5)//数据在每个类的概率,这里
  var p=Array.ofDim[Double](2,11)//每个点属于每个类的后验概率
  //求得一个特定的点在相应的高斯函数中的概率
  def gaussianPro(point:Double,means:Double,probs:Double):Double={
    1/(math.sqrt(2*math.Pi)*math.sqrt(probs))*math.exp(-(point-means)*(point-means)/(2*probs))
  }
  //计算每个点属于每个类的后验概率
  def Computepostpro(){
    //计算分母
    var denominator=new Array[Double](data.size)
    for(i<- 0 until data.length){
     for(j<- 0 until probs.length)
       denominator(i)+=gaussianPro(data(i), means(j), variances(j))*probs(j)
    }
    //denominator.foreach(println)
    //计算每个点属于每个类的后验概率
    for(i<- 0 until probs.length){
       for(j<- 0 until data.length)
          p(i)(j)=probs(i)*gaussianPro(data(j), means(i), variances(i))/denominator(j)
    }
  }
  def GuessMean(){
    var m1=new Array[Double](2)
    var m2=new Array[Double](2)
    for(i<- 0 until 2){
     p(i).foreach(m2(i)+=_)
     p(i).zip(data).foreach(f=>m1(i)+=f._1*f._2)
     means(i)=m1(i)/m2(i)
    }
  }
  def Guessvariances(){
    var m1=new Array[Double](2)
    var m2=new Array[Double](2)
    for(i<- 0 until 2){
     p(i).foreach(m2(i)+=_)
     p(i).zip(data).foreach(f=>m1(i)+=f._1*(f._2-means(i))*(f._2-means(i)))
     variances(i)=m1(i)/m2(i)
    }
  }
  def ComputeProbs(){
    for(i<-0 until 2){
      probs(i)=p(i).reduce(_+_)/p(i).length
    }
  }
  def E_Step(){
    Computepostpro()
  }
  def M_Step(){
    GuessMean()
    Guessvariances()
    ComputeProbs()
  }   
  def Threshold(localmean:Array[Double],localvariances:Array[Double]):Boolean={
    var m1,m2=0.0
    for(i<-0 until means.length){
      m1+=localmean(i)-means(i)
      m2+=localvariances(i)-variances(i)
    }
    (math.abs(m1)<0.01)&&(math.abs(m2)<0.01)
  }
  def Output(){
    for(i<-0 until 2)
		   println("第"+i+"个类的均值为:"+means(i)+" 方差为:"+variances(i)+" 选择这个类的概率为:"+probs(i));
  }
  def main(args: Array[String]): Unit = {
    var k=20//迭代次数  
    var f=true
    while(k>0&&f)
   {
      k-=1
     var localmean=means
     var localvariances=variances
     E_Step()
     M_Step()
     f=Threshold(localmean, localvariances)
   }
    Output
  }
  
}



  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值