本文作者:合肥工业大学 管理学院 钱洋 email:1563178220@qq.com 内容可能有不到之处,欢迎交流。
未经本人允许禁止转载
问题背景
最近,在独立实现Gaussian LDA算法时,遇到了Multivariate t distributions。Gaussian LDA对应的论文是:
Das R, Zaheer M, Dyer C. Gaussian lda for topic models with word embeddings[C]//Proceedings of the 53rd Annual Meeting of the Association for Computational Linguistics and the 7th International Joint Conference on Natural Language Processing (Volume 1: Long Papers). 2015, 1: 795-804.
我的另外一篇博客,已对参数学习公式的推理进行了详细说明,即:
https://qianyang-hfut.blog.csdn.net/article/details/79528781
其中,在实现该算法时,需要对每个单词所属的主题的概率进行计算,即通过如下公式:
从上面的公式可以看到,最为主要的是计算t分布函数,其输入是word2Vec产生的词向量,以及均值向量和协方差矩阵。
多元t分布函数的公式如下:
编程实现
在math3中,提供了很多函数,如Gamma函数,高斯函数等;也提供了很多分布,如高斯分布,Beta分布,二项分布等的采样。下面使用math3实现一元t分布的采样:
import org.apache.commons.math3.distribution.TDistribution;
public class Test3 {
public static void main(String[] args) {
//定义自由度
TDistribution t = new TDistribution(5);
for (int i = 0; i < 100; i++) {
System.out.println(t.sample());
}
}
}
程序的输入结果为:
下图为一元t分布对应的分布图。
在math3只提供了一元t分布和其概率密度函数,但并多元t分布的函数。
在实现Gaussian LDA的过程中,我自己又编写了多元t分布的函数,即下面公式对应的代码:
如下为Java代码:
/**The probability density function of the Multivariate t distributions
* Reference: Conjugate Bayesian analysis of the Gaussian distribution
*
*
* @param ArrayRealVector dataPoint
* @param ArrayRealVector meansVector
* @param RealMatrix covarianceMatrix
* @param double degreesOfFreedom
* @return The log probability value
*
*
* @author Qianyang
* ****/
public static double logMultivariateTDensity(ArrayRealVector dataPoint, ArrayRealVector meansVector, RealMatrix covarianceMatrix, double degreesOfFreedom){
LUDecomposition covariance = new LUDecomposition(covarianceMatrix);
double logprob_left = Gamma.logGamma((degreesOfFreedom + dataPoint.getDimension())/2.0) -
(Gamma.logGamma(degreesOfFreedom / 2.0) + 0.5 * Math.log(covariance.getDeterminant()) +
dataPoint.getDimension()/2.0 * (Math.log(degreesOfFreedom) + Math.log(Math.PI)));
// compute x-u
ArrayRealVector var = dataPoint.add(meansVector.mapMultiplyToSelf(-1.0));
// (x-u) to matrix
RealMatrix realMatrix = new Array2DRowRealMatrix(var.getDataRef());
//compute left
double logprob_right = Math.log(1 + realMatrix.transpose().multiply(new LUDecomposition(covarianceMatrix).getSolver().getInverse())
.multiply(realMatrix).getData()[0][0]/degreesOfFreedom);
return Math.exp(logprob_left -(degreesOfFreedom + dataPoint.getDimension())/2.0 * logprob_right);
}
使用方法
如下为使用方法:
ArrayRealVector dataPoint = new ArrayRealVector(new double[]{2,5});
ArrayRealVector meansVector = new ArrayRealVector(new double[]{2,1});
RealMatrix covarianceMatrix = MatrixUtils.createRealMatrix(new double[][]{{1, 0}, {0, 1}});
double logt = logMultivariateTDensity(dataPoint,meansVector,covarianceMatrix,10);
System.out.println(logt);
主方法执行这些代码,会输出如下结果。
改写协方差和自由度:
ArrayRealVector dataPoint = new ArrayRealVector(new double[]{2,5});
ArrayRealVector meansVector = new ArrayRealVector(new double[]{2,1});
RealMatrix covarianceMatrix = MatrixUtils.createRealMatrix(new double[][]{{3, 0.8}, {0.8, 5}});
double logt = logMultivariateTDensity(dataPoint,meansVector,covarianceMatrix,17);
System.out.println(logt);
程序输出结果为: