本文作者:合肥工业大学 管理学院 钱洋 email:1563178220@qq.com 。
内容可能有不到之处,欢迎交流。
未经本人允许禁止转载。
模型
如下为模型:
这个模型中,参数和协方差服从正态逆Wishart(Normal-Inverse-Wishart)先验。
根据模型,给出所有变量的联合似然,即:
公式推理
因变量和权重的联合概率分布可表示为:
求对数:
其中,
EM求解
令:
则:
编程
下面,给出了EM算法迭代的核心代码:
/**
 * Runs EM passes until the variance estimate stops changing.
 *
 * <p>Each pass runs the E-step ({@code CalculateExpectaction}) with the current
 * mean, covariance and variance, then the M-step ({@code CalculateMaximization})
 * on the E-step output. The iteration ends when the absolute difference between
 * the new variance and the variance used for that pass is below {@code epsilon}.
 *
 * <p>The previous recursive version discarded the result of the recursive call and
 * therefore always returned the matrices of the very first E-step; this version
 * returns the matrices of the final (converged) pass and does not grow the stack.
 *
 * @param meaninitial initial population mean
 * @param spmpling_inverse_wishartinitial initial population covariance
 * @param varianceinitial initial variance
 * @param xMatrixList per-item x matrices, keyed by item number
 * @param ymatrix per-item y matrices, keyed by item number
 * @param epsilon convergence threshold on the change of the variance
 * @return the first matrix of each item's E-step result (the w matrix) from the
 *         last pass, keyed by item number
 * @throws IllegalStateException if the M-step yields a NaN variance, since the
 *         convergence test could then never succeed
 * @author Qianyang
 */
public static Map<Integer, RealMatrix> expectationMaximizationUpdating(double [] meaninitial,RealMatrix spmpling_inverse_wishartinitial,double varianceinitial,Map<Integer,RealMatrix > xMatrixList,Map<Integer,RealMatrix > ymatrix,double epsilon) {
    // Current parameter estimates; refreshed after every non-converged pass.
    double[] mean = meaninitial;
    RealMatrix covariance = spmpling_inverse_wishartinitial;
    double variance = varianceinitial;

    while (true) {
        // E-step: per item, element 0 is the w matrix and element 1 its covariance.
        Map<Integer, List<RealMatrix>> edata = CalculateExpectaction(mean, covariance, variance, xMatrixList, ymatrix);
        Map<Integer, RealMatrix> wMatrixList = new HashMap<Integer, RealMatrix>();
        Map<Integer, RealMatrix> covariancematrixlist = new HashMap<Integer, RealMatrix>();
        for (Map.Entry<Integer, List<RealMatrix>> entry : edata.entrySet()) {
            wMatrixList.put(entry.getKey(), entry.getValue().get(0));
            covariancematrixlist.put(entry.getKey(), entry.getValue().get(1));
        }

        // M-step: re-estimate mean, covariance and variance.
        MStepData mStepData = CalculateMaximization(wMatrixList, covariancematrixlist, xMatrixList, ymatrix);
        RealMatrix meanupdate = mStepData.getMeanupdate();
        RealMatrix covarianceupdate = mStepData.getCovarianceupdate();
        double varianceupdate = mStepData.getVarianceupdate();
        System.out.println(varianceupdate);

        if (Double.isNaN(varianceupdate)) {
            throw new IllegalStateException("EM variance update is NaN; cannot converge");
        }

        if (Math.abs(varianceupdate - variance) < epsilon) {
            System.out.println("meanupdate:"+meanupdate+"\tcovarianceupdate:"+covarianceupdate+"\tvarianceupdate:"+varianceupdate);
            return wMatrixList;
        }

        // Not converged yet: count the pass and feed the new estimates back in.
        iter++;
        System.out.println("the current iter:\t"+iter);
        mean = meanupdate.getColumnVector(0).toArray();
        covariance = covarianceupdate;
        variance = varianceupdate;
    }
}
其中,E步的代码为:
/**
 * E-step: computes, for every item in {@code ymatrix}, the w matrix and its
 * covariance under the current parameters.
 *
 * <p>With P = inverse(spmpling_inverse_wishart), X the item's x matrix and y the
 * item's y matrix, the code evaluates
 * <pre>
 *   C = inverse(P + X^T X / variance)
 *   w = C * (X^T y^T / variance + P * mean)
 * </pre>
 * and stores {@code [w, C]} for the item.
 *
 * <p>Terms that do not depend on the item (P, the mean column matrix and
 * P * mean) are computed once, and C is inverted once per item and reused for w.
 *
 * @param mean population mean (per the original notes, a draw from a
 *        multivariate normal distribution)
 * @param spmpling_inverse_wishart population covariance (per the original notes,
 *        a draw from an inverse Wishart distribution)
 * @param variance user-defined variance
 * @param xmatrix per-item x matrices, keyed by item number
 * @param ymatrix per-item y matrices, keyed by item number; its keys drive the loop
 * @return for each item number, a two-element list: the w matrix, then the covariance
 * @author Qianyang
 */
private static Map<Integer,List<RealMatrix>> CalculateExpectaction(double [] mean,RealMatrix spmpling_inverse_wishart,double variance,Map<Integer,RealMatrix > xmatrix,Map<Integer,RealMatrix > ymatrix) {
    Map<Integer, List<RealMatrix>> weMap = new HashMap<Integer, List<RealMatrix>>();
    if (ymatrix.isEmpty()) {
        // Nothing to estimate; also avoids inverting the prior when it is not needed.
        return weMap;
    }

    // Item-independent terms, hoisted out of the loop.
    RealMatrix priorPrecision = inverseMatrix(spmpling_inverse_wishart);
    RealMatrix meanmatrix = new Array2DRowRealMatrix(mean);
    RealMatrix precisionTimesMean = priorPrecision.multiply(meanmatrix);
    double inverseVariance = 1 / variance;

    for (Map.Entry<Integer, RealMatrix> entry : ymatrix.entrySet()) {
        int itemnumber = entry.getKey();
        RealMatrix x = xmatrix.get(itemnumber);
        RealMatrix xTransposed = x.transpose();

        // X^T X / variance
        RealMatrix xxMatrix = xTransposed.multiply(x).scalarMultiply(inverseVariance);

        // The covariance; the same matrix is the left factor of w, so invert only once.
        RealMatrix covariancematrix = inverseMatrix(priorPrecision.add(xxMatrix));

        // X^T y^T / variance + P * mean
        RealMatrix second = xTransposed.multiply(entry.getValue().transpose())
                .scalarMultiply(inverseVariance)
                .add(precisionTimesMean);
        RealMatrix wMatrix = covariancematrix.multiply(second);

        // Element 0: w matrix, element 1: covariance.
        List<RealMatrix> listRealMatrix = new ArrayList<RealMatrix>();
        listRealMatrix.add(wMatrix);
        listRealMatrix.add(covariancematrix);
        weMap.put(itemnumber, listRealMatrix);
    }
    return weMap;
}
M步的代码为:
/**
 * M-step: re-estimates the mean, covariance and variance from the E-step output
 * and packs them into an {@code MStepData}.
 *
 * <p>The mean is {@code updatew(wMatrixList)} scaled by 1/N and the covariance is
 * {@code updatecovariance(wMatrixList, covariancematrixlist, meanupdate)} scaled by
 * 1/N, where N is {@code wMatrixList.size()}. The variance is taken directly from
 * {@code updatevariance(wMatrixList, xmatrix, ymatrix)}.
 * NOTE(review): the helpers presumably return sums over items, so the 1/N scaling
 * turns them into averages; confirm against their definitions.
 *
 * @param wMatrixList per-item w matrices from the E-step, keyed by item number
 * @param covariancematrixlist per-item covariances from the E-step, keyed by item number
 * @param xmatrix per-item x matrices, keyed by item number
 * @param ymatrix per-item y matrices, keyed by item number
 * @return the updated mean, covariance and variance
 * @author Qianyang
 */
private static MStepData CalculateMaximization(Map<Integer,RealMatrix > wMatrixList,Map<Integer,RealMatrix > covariancematrixlist,Map<Integer,RealMatrix > xmatrix,Map<Integer,RealMatrix > ymatrix) {
    // Mean update, scaled by 1/N.
    RealMatrix meanupdate = updatew(wMatrixList).scalarMultiply(1.0 / wMatrixList.size());
    // Covariance update, scaled by 1/N; depends on the freshly updated mean.
    RealMatrix covarianceupdate = updatecovariance(wMatrixList, covariancematrixlist, meanupdate)
            .scalarMultiply(1.0 / wMatrixList.size());
    // Variance update.
    double varianceupdate = updatevariance(wMatrixList, xmatrix, ymatrix);

    MStepData datamodel = new MStepData();
    datamodel.setMeanupdate(meanupdate);
    datamodel.setCovarianceupdate(covarianceupdate);
    datamodel.setVarianceupdate(varianceupdate);
    return datamodel;
}
该算法的完整代码在本人的Github上:https://github.com/soberqian/NIWSamplingProcess