// EM algorithm (expectation-maximization)
package EM;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
/**
 * Toy expectation-maximization (EM) demo for a one-dimensional mixture of Gaussians.
 *
 * <p>Every component is modelled with unit variance and all components are weighted
 * equally, so the only parameters being estimated are the component means.
 */
public class KGaussEM {
    /** Number of EM iterations performed by {@link #main}. */
    private static final int ITERATION_COUNT = 100;

    /** Initial mean guesses are drawn uniformly from [0, INITIAL_GUESS_RANGE). */
    private static final double INITIAL_GUESS_RANGE = 30d;

    /**
     * Generates samples from four unit-variance Gaussians, starts from random mean
     * guesses, runs a fixed number of EM iterations and prints the final estimates.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        int count = 10000;
        // Gaussian mixture: averageList holds the mean of each component.
        List<Double> averageList = new ArrayList<Double>();
        averageList.add(0d);
        averageList.add(4d);
        averageList.add(8d);
        averageList.add(12d);
        List<Double> sampleList = generateSample(count, averageList);

        // One generator for all initial guesses instead of a fresh Random per guess.
        Random random = new Random();
        List<Double> estamateList = new ArrayList<Double>();
        for (int i = 0; i < averageList.size(); i++) {
            estamateList.add(random.nextDouble() * INITIAL_GUESS_RANGE);
        }

        // The iteration counter is local; it does not need to be shared static state.
        for (int iteration = 0; iteration < ITERATION_COUNT; iteration++) {
            System.out.println("第" + iteration + "次验证");
            List<List<Double>> EListList = E(sampleList, estamateList);
            estamateList = M(EListList, sampleList);
        }
        System.out.println(estamateList.toString());
    }

    /**
     * M-step (maximum likelihood): re-estimates each component mean as the
     * responsibility-weighted average of the samples.
     *
     * <p>A component whose responsibilities sum to zero yields {@code NaN} (0/0).
     *
     * @param EListList  one row per sample; row {@code j} holds the responsibility of
     *                   every component for sample {@code j}
     * @param sampleList the samples, in the same order as the rows of {@code EListList}
     * @return the new mean estimate of every component
     * @throws IllegalArgumentException if {@code EListList} is empty
     */
    private static List<Double> M(List<List<Double>> EListList, List<Double> sampleList) {
        if (EListList.isEmpty()) {
            throw new IllegalArgumentException("cannot re-estimate means without any samples");
        }
        int componentCount = EListList.get(0).size();
        List<Double> newEstamateList = new ArrayList<Double>(componentCount);
        for (int i = 0; i < componentCount; i++) {
            double weightSum = 0d;
            double weightedSampleSum = 0d;
            for (int j = 0; j < EListList.size(); j++) {
                double weight = EListList.get(j).get(i);
                weightedSampleSum += weight * sampleList.get(j);
                weightSum += weight;
            }
            newEstamateList.add(weightedSampleSum / weightSum);
        }
        return newEstamateList;
    }

    /**
     * E-step (expectation): computes, for every sample, the responsibility of each
     * component, i.e. the component densities normalised to sum to one.
     *
     * @param sampleList   the observed samples
     * @param estamateList the current mean estimate of every component
     * @return one row per sample, each row holding one responsibility per component
     */
    private static List<List<Double>> E(List<Double> sampleList, List<Double> estamateList) {
        int componentCount = estamateList.size();
        double[] density = new double[componentCount];
        List<List<Double>> EListList = new ArrayList<List<Double>>(sampleList.size());
        for (double sample : sampleList) {
            double maxLogDensity = Double.NEGATIVE_INFINITY;
            for (int k = 0; k < componentCount; k++) {
                density[k] = logGsValue(estamateList.get(k), sample);
                maxLogDensity = Math.max(maxLogDensity, density[k]);
            }
            // Shift by the largest exponent before exponentiating: the closest component
            // then contributes exp(0) = 1, so the normaliser cannot underflow to zero
            // (which would turn every responsibility into 0/0 = NaN).
            double pCount = 0d;
            for (int k = 0; k < componentCount; k++) {
                density[k] = Math.exp(density[k] - maxLogDensity);
                pCount += density[k];
            }
            List<Double> EList = new ArrayList<Double>(componentCount);
            for (int k = 0; k < componentCount; k++) {
                EList.add(density[k] / pCount);
            }
            EListList.add(EList);
        }
        return EListList;
    }

    /**
     * Natural logarithm of the unnormalised unit-variance Gaussian density
     * {@code exp(-(x - u)^2 / 2)}.
     *
     * @param u the mean of the Gaussian
     * @param x the point at which to evaluate
     * @return {@code -(x - u)^2 / 2}
     */
    private static double logGsValue(double u, double x) {
        double diff = x - u;
        return -(diff * diff) / 2;
    }

    /**
     * Draws samples from several Gaussians with standard deviation 1.
     *
     * <p>The result is interleaved: for each of the {@code count} rounds it holds one
     * independently drawn sample per entry of {@code averageList}, in list order.
     *
     * @param count       number of samples drawn from each distribution
     * @param averageList the mean of each Gaussian
     * @return {@code count * averageList.size()} samples
     */
    private static List<Double> generateSample(int count, List<Double> averageList) {
        List<Double> list = new ArrayList<Double>();
        Random r = new Random();
        for (int i = 0; i < count; i++) {
            for (Double av : averageList) {
                // Fresh noise per sample; reusing one draw for every mean would make
                // the components perfectly correlated.
                list.add(r.nextGaussian() + av);
            }
        }
        return list;
    }
}