几种概率语言模型和参数学习方法

9 篇文章 0 订阅
7 篇文章 0 订阅

 

From:http://blog.csdn.net/yangliuy/article/details/8330640

            http://blog.csdn.net/yangliuy/article/details/8302599

            http://blog.csdn.net/yangliuy/article/details/8457329

 

*********************************************************************************************************************************

第一篇 PLSA及EM算法

[本文PDF版本下载地址 PLSA及EM算法-yangliuy]

前言:本文主要介绍PLSA及EM算法,首先给出LSA(隐性语义分析)的早期方法SVD,然后引入基于概率的PLSA模型,其参数学习采用EM算法。接着我们分析如何运用EM算法估计一个简单的mixture unigram 语言模型和混合高斯模型GMM的参数,最后总结EM算法的一般形式及运用关键点。对于改进PLSA,引入hyperparameter的LDA模型及其Gibbs Sampling参数估计方法放在本系列后面的文章LDA及Gibbs Samping介绍。


1 LSA and SVD

LSA(隐性语义分析)的目的是要从文本中发现隐含的语义维度-即“Topic”或者“Concept”。我们知道,在文档的空间向量模型(VSM)中,文档被表示成由特征词出现概率组成的多维向量,这种方法的好处是可以将query和文档转化成同一空间下的向量计算相似度,可以对不同词项赋予不同的权重,在文本检索、分类、聚类问题中都得到了广泛应用,在基于贝叶斯算法及KNN算法的newsgroup18828文本分类器的JAVA实现基于Kmeans算法、MBSAS算法及DBSCAN算法的newsgroup18828文本聚类器的JAVA实现系列文章中的分类聚类算法大多都是采用向量空间模型。然而,向量空间模型没有能力处理一词多义和一义多词问题,例如同义词也分别被表示成独立的一维,计算向量的余弦相似度时会低估用户期望的相似度;而某个词项有多个词义时,始终对应同一维度,因此计算的结果会高估用户期望的相似度。


LSA方法的引入就可以减轻类似的问题。基于SVD分解,我们可以构造一个原始向量矩阵的一个低秩逼近矩阵,具体的做法是将词项文档矩阵做SVD分解




  其中是以词项(terms)为行, 文档(documents)为列做一个大矩阵. 设一共有t行d列,  矩阵的元素为词项的tf-idf值。然后的r个对角元素的前k个保留(最大的k个保留), 后面最小的r-k个奇异值置0, 得到;最后计算一个近似的分解矩阵




在最小二乘意义下是的最佳逼近。由于最多包含k个非零元素,所以的秩不超过k。通过在SVD分解近似,我们将原始的向量转化成一个低维隐含语义空间中,起到了特征降维的作用。每个奇异值对应的是每个“语义”维度的权重,将不太重要的权重置为0,只保留最重要的维度信息,去掉一些信息“nosie”,因而可以得到文档的一种更优表示形式。将SVD分解降维应用到文档聚类的JAVA实现可参见此文

IIR中给出的一个SVD降维的实例如下:


左边是原始矩阵的SVD分解,右边是只保留权重最大2维,将原始矩阵降到2维后的情况。


2 PLSA

尽管基于SVD的LSA取得了一定的成功,但是其缺乏严谨的数理统计基础,而且SVD分解非常耗时。Hofmann在SIGIR'99上提出了基于概率统计的PLSA模型,并且用EM算法学习模型参数。PLSA的概率图模型如下



 

 

其中D代表文档,Z代表隐含类别或者主题,W为观察到的单词,表示单词出现在文档的概率,表示文档中出现主题下的单词的概率,给定主题出现单词的概率。并且每个主题在所有词项上服从Multinomial 分布,每个文档在所有主题上服从Multinomial 分布。整个文档的生成过程是这样的:

(1) 以的概率选中文档

(2) 以的概率选中主题

(3) 以的概率产生一个单词。

我们可以观察到的数据就是对,而是隐含变量。的联合分布为




分布对应了两组Multinomial 分布,我们需要估计这两组分布的参数。下面给出用EM算法估计PLSA参数的详细推导过程。


3 Estimate parameters in PLSA  by EM

(注:本部分主要参考Tomas Hoffman, Unsupervised Learning by Probabilistic Latent Semantic Analysis.

文本语言模型的参数估计-最大似然估计、MAP及贝叶斯估计一文所述,常用的参数估计方法有MLE、MAP、贝叶斯估计等等。但是在PLSA中,如果我们试图直接用MLE来估计参数,就会得到似然函数




其中是term 出现在文档中的次数。n(di)表示文档di中的总词数。注意这是一个关于的函数,一共有N*K + M*K个自变量(注意这里M表示term的总数,一般文献习惯用V表示),如果直接对这些自变量求偏导数,我们会发现由于自变量包含在对数和中,这个方程的求解很困难。因此对于这样的包含“隐含变量”或者“缺失数据”的概率模型参数估计问题,我们采用EM算法。


EM算法的步骤是:

(1)E步骤:求隐含变量Given当前估计的参数条件下的后验概率。

(2)M步骤:最大化Complete data对数似然函数的期望,此时我们使用E步骤里计算的隐含变量的后验概率,得到新的参数值。

两步迭代进行直到收敛。


先解释一下什么是Incomplete data和complete data。Zhai老师在一篇经典的EM算法Notes中讲到,当原始数据的似然函数很复杂时,我们通过增加一些隐含变量来增强我们的数据,得到“complete data”,而“complete data”的似然函数更加简单,方便求极大值。于是,原始的数据就成了“incomplete data”。我们将会看到,我们可以通过最大化“complete data”似然函数的期望来最大化"incomplete data"的似然函数,以便得到求似然函数最大值更为简单的计算途径。


针对我们PLSA参数估计问题,在E步骤中,直接使用贝叶斯公式计算隐含变量在当前参数取值条件下的后验概率,有




在这个步骤中,我们假定所有的都是已知的,因为初始时随机赋值,后面迭代的过程中取前一轮M步骤中得到的参数值。


在M步骤中,我们最大化Complete data对数似然函数的期望。在PLSA中,Incomplete data 是观察到的,隐含变量是主题,那么complete data就是三元组,其期望是




注意这里是已知的,取的是前面E步骤里面的估计值。下面我们来最大化期望,这又是一个多元函数求极值的问题,可以用拉格朗日乘数法。拉格朗日乘数法可以把条件极值问题转化为无条件极值问题,在PLSA中目标函数就是,约束条件是




由此我们可以写出拉格朗日函数




这是一个关于的函数,分别对其求偏导数,我们可以得到




注意这里进行过方程两边同时乘以的变形,联立上面4组方程,我们就可以解出M步骤中通过最大化期望估计出的新的参数值




解方程组的关键在于先求出,其实只需要做一个加和运算就可以把的系数都化成1,后面就好计算了。

然后使用更新后的参数值,我们又进入E步骤,计算隐含变量 Given当前估计的参数条件下的后验概率。如此不断迭代,直到满足终止条件。

注意到我们在M步骤中还是使用对Complete Data的MLE,那么如果我们想加入一些先验知识进入我们的模型,我们可以在M步骤中使用MAP估计。正如文本语言模型的参数估计-最大似然估计、MAP及贝叶斯估计中投硬币的二项分布实验中我们加入“硬币一般是两面均匀的”这个先验一样。而由此计算出的参数的估计值会在分子分母中多出关于先验参数的preduo counts,其他步骤都是一样的。具体可以参考Mei Qiaozhu 的Notes

 ——————————————————

EM算法解法:

——————————————————

PLSA的实现也不难,网上有很多实现code。

例如这个PLSA的EM算法实现 http://ezcodesample.com/plsaidiots/PLSAjava.txt

主要的类如下(作者Andrew Polar)

 

[java]   view plain copy
  1. //The code is taken from:  
  2. //http://code.google.com/p/mltool4j/source/browse/trunk/src/edu/thu/mltool4j/topicmodel/plsa  
  3. //I noticed some difference with original Hofmann concept in computation of P(z). It is   
  4. //always even and actually not involved, that makes this algorithm non-negative matrix   
  5. //factoring and not PLSA.  
  6. //Found and tested by Andrew Polar.   
  7. //My version can be found on semanticsearchart.com or ezcodesample.com  

 

 

[java]   view plain copy
  1. class ProbabilisticLSA  
  2. {  
  3.     private Dataset dataset = null;  
  4.     private Posting[][] invertedIndex = null;  
  5.     private int M = -1// number of data  
  6.     private int V = -1// number of words  
  7.     private int K = -1// number of topics  
  8.   
  9.     public ProbabilisticLSA()  
  10.     {  
  11.     }  
  12.   
  13.     public boolean doPLSA(String datafileName, int ntopics, int iters)  
  14.     {  
  15.         File datafile = new File(datafileName);  
  16.         if (datafile.exists())  
  17.         {  
  18.             if ((this.dataset = new Dataset(datafile)) == null)  
  19.             {  
  20.                 System.out.println("doPLSA, dataset == null");  
  21.                 return false;  
  22.             }  
  23.             this.M = this.dataset.size();  
  24.             this.V = this.dataset.getFeatureNum();  
  25.             this.K = ntopics;  
  26.               
  27.              //build inverted index  
  28.             this.buildInvertedIndex(this.dataset);  
  29.             //run EM algorithm  
  30.             this.EM(iters);  
  31.             return true;  
  32.               
  33.         }  
  34.         else  
  35.         {  
  36.             System.out.println("ProbabilisticLSA(String datafileName), datafile: " + datafileName + " doesn't exist");  
  37.             return false;  
  38.         }  
  39.     }  
  40.   
  41.     //Build the inverted index for M-step fast calculation. Format:  
  42.     //invertedIndex[w][]: a unsorted list of document and position which word w  
  43.     // occurs.   
  44.     //@param ds the dataset which to be analysis  
  45.     @SuppressWarnings("unchecked")  
  46.     private boolean buildInvertedIndex(Dataset ds)  
  47.     {  
  48.         ArrayList<Posting>[] list = new ArrayList[this.V];  
  49.         for (int k=0; k<this.V; ++k) {  
  50.             list[k] = new ArrayList<Posting>();  
  51.         }  
  52.               
  53.         for (int m = 0; m < this.M; m++)  
  54.         {  
  55.             Data d = ds.getDataAt(m);  
  56.             for (int position = 0; position < d.size(); position++)  
  57.             {  
  58.                 int w = d.getFeatureAt(position).dim;  
  59.                 // add posting  
  60.                 list[w].add(new Posting(m, position));  
  61.             }  
  62.         }  
  63.         // convert to array  
  64.         this.invertedIndex = new Posting[this.V][];  
  65.         for (int w = 0; w < this.V; w++)  
  66.         {  
  67.             this.invertedIndex[w] = list[w].toArray(new Posting[0]);  
  68.         }  
  69.         return true;  
  70.     }  
  71.   
  72.     private boolean EM(int iters)  
  73.     {  
  74.         // p(z), size: K  
  75.         double[] Pz = new double[this.K];  
  76.   
  77.         // p(d|z), size: K x M  
  78.         double[][] Pd_z = new double[this.K][this.M];  
  79.   
  80.         // p(w|z), size: K x V  
  81.         double[][] Pw_z = new double[this.K][this.V];  
  82.   
  83.         // p(z|d,w), size: K x M x doc.size()  
  84.         double[][][] Pz_dw = new double[this.K][this.M][];  
  85.   
  86.          // L: log-likelihood value  
  87.          double L = -1;  
  88.   
  89.          // run EM algorithm  
  90.          this.init(Pz, Pd_z, Pw_z, Pz_dw);  
  91.          for (int it = 0; it < iters; it++)  
  92.          {  
  93.              // E-step  
  94.              if (!this.Estep(Pz, Pd_z, Pw_z, Pz_dw))  
  95.              {  
  96.                  System.out.println("EM,  in E-step");  
  97.              }  
  98.   
  99.              // M-step  
  100.              if (!this.Mstep(Pz_dw, Pw_z, Pd_z, Pz))  
  101.              {  
  102.                  System.out.println("EM, in M-step");  
  103.              }  
  104.   
  105.              L = calcLoglikelihood(Pz, Pd_z, Pw_z);  
  106.              System.out.println("[" + it + "]" + "\tlikelihood: " + L);  
  107.          }  
  108.                   
  109.          //print result  
  110.          for (int m = 0; m < this.M; m++)  
  111.          {  
  112.              double norm = 0.0;  
  113.              for (int z = 0; z < this.K; z++) {  
  114.                  norm += Pd_z[z][m];  
  115.              }  
  116.              if (norm <= 0.0) norm = 1.0;  
  117.              for (int z = 0; z < this.K; z++) {  
  118.                  System.out.format("%10.4f", Pd_z[z][m]/norm);  
  119.              }  
  120.              System.out.println();  
  121.         }   
  122.         return false;  
  123.     }  
  124.      
  125.     private boolean init(double[] Pz, double[][] Pd_z, double[][] Pw_z, double[][][] Pz_dw)  
  126.     {  
  127.         // p(z), size: K  
  128.         double zvalue = (double1 / (doublethis.K;  
  129.         for (int z = 0; z < this.K; z++)  
  130.         {  
  131.             Pz[z] = zvalue;  
  132.         }  
  133.   
  134.         // p(d|z), size: K x M  
  135.         for (int z = 0; z < this.K; z++)  
  136.         {  
  137.             double norm = 0.0;  
  138.             for (int m = 0; m < this.M; m++)  
  139.             {  
  140.                 Pd_z[z][m] = Math.random();  
  141.                 norm += Pd_z[z][m];  
  142.             }  
  143.   
  144.             for (int m = 0; m < this.M; m++)  
  145.             {  
  146.                 Pd_z[z][m] /= norm;  
  147.             }  
  148.         }  
  149.   
  150.         // p(w|z), size: K x V  
  151.         for (int z = 0; z < this.K; z++)  
  152.         {  
  153.             double norm = 0.0;  
  154.             for (int w = 0; w < this.V; w++)  
  155.             {  
  156.                 Pw_z[z][w] = Math.random();  
  157.                 norm += Pw_z[z][w];  
  158.             }  
  159.   
  160.             for (int w = 0; w < this.V; w++)  
  161.             {  
  162.                 Pw_z[z][w] /= norm;  
  163.             }  
  164.         }  
  165.   
  166.         // p(z|d,w), size: K x M x doc.size()  
  167.         for (int z = 0; z < this.K; z++)  
  168.         {  
  169.             for (int m = 0; m < this.M; m++)  
  170.             {  
  171.                 Pz_dw[z][m] = new double[this.dataset.getDataAt(m).size()];  
  172.             }  
  173.         }  
  174.         return false;  
  175.     }  
  176.   
  177.     private boolean Estep(double[] Pz, double[][] Pd_z, double[][] Pw_z, double[][][] Pz_dw)  
  178.     {  
  179.         for (int m = 0; m < this.M; m++)  
  180.         {  
  181.             Data data = this.dataset.getDataAt(m);  
  182.             for (int position = 0; position < data.size(); position++)  
  183.             {  
  184.                 // get word(dimension) at current position of document m  
  185.                 int w = data.getFeatureAt(position).dim;  
  186.   
  187.                 double norm = 0.0;  
  188.                 for (int z = 0; z < this.K; z++)  
  189.                 {  
  190.                     double val = Pz[z] * Pd_z[z][m] * Pw_z[z][w];  
  191.                     Pz_dw[z][m][position] = val;  
  192.                     norm += val;  
  193.                 }  
  194.   
  195.                 // normalization  
  196.                 for (int z = 0; z < this.K; z++)  
  197.                 {  
  198.                     Pz_dw[z][m][position] /= norm;  
  199.                 }  
  200.             }  
  201.         }  
  202.         return true;  
  203.     }  
  204.   
  205.     private boolean Mstep(double[][][] Pz_dw, double[][] Pw_z, double[][] Pd_z, double[] Pz)  
  206.     {  
  207.         // p(w|z)  
  208.         for (int z = 0; z < this.K; z++)  
  209.         {  
  210.             double norm = 0.0;  
  211.             for (int w = 0; w < this.V; w++)  
  212.             {  
  213.                 double sum = 0.0;  
  214.   
  215.                 Posting[] postings = this.invertedIndex[w];  
  216.                 for (Posting posting : postings)  
  217.                 {  
  218.                     int m = posting.docID;  
  219.                     int position = posting.pos;  
  220.                     double n = this.dataset.getDataAt(m).getFeatureAt(position).weight;  
  221.                     sum += n * Pz_dw[z][m][position];  
  222.                 }  
  223.                 Pw_z[z][w] = sum;  
  224.                 norm += sum;  
  225.             }  
  226.   
  227.             // normalization  
  228.             for (int w = 0; w < this.V; w++)  
  229.             {  
  230.                 Pw_z[z][w] /= norm;  
  231.             }  
  232.         }  
  233.   
  234.         // p(d|z)  
  235.         for (int z = 0; z < this.K; z++)  
  236.         {  
  237.             double norm = 0.0;  
  238.             for (int m = 0; m < this.M; m++)  
  239.             {  
  240.                 double sum = 0.0;  
  241.                 Data d = this.dataset.getDataAt(m);  
  242.                 for (int position = 0; position < d.size(); position++)  
  243.                 {  
  244.                     double n = d.getFeatureAt(position).weight;  
  245.                     sum += n * Pz_dw[z][m][position];  
  246.                 }  
  247.                 Pd_z[z][m] = sum;  
  248.                 norm += sum;  
  249.             }  
  250.   
  251.             // normalization  
  252.             for (int m = 0; m < this.M; m++)  
  253.             {  
  254.                 Pd_z[z][m] /= norm;  
  255.             }  
  256.         }  
  257.   
  258.         //This is definitely a bug  
  259.         //p(z) values are even, but they should not be even  
  260.         double norm = 0.0;  
  261.         for (int z = 0; z < this.K; z++)  
  262.         {  
  263.             double sum = 0.0;  
  264.             for (int m = 0; m < this.M; m++)  
  265.             {  
  266.                 sum += Pd_z[z][m];  
  267.             }  
  268.             Pz[z] = sum;  
  269.             norm += sum;  
  270.        }  
  271.   
  272.         // normalization  
  273.         for (int z = 0; z < this.K; z++)  
  274.         {  
  275.             Pz[z] /= norm;  
  276.             //System.out.format("%10.4f", Pz[z]);  //here you can print to see  
  277.         }  
  278.         //System.out.println();  
  279.   
  280.         return true;  
  281.     }  
  282.   
  283.     private double calcLoglikelihood(double[] Pz, double[][] Pd_z, double[][] Pw_z)  
  284.     {  
  285.         double L = 0.0;  
  286.         for (int m = 0; m < this.M; m++)  
  287.         {  
  288.             Data d = this.dataset.getDataAt(m);  
  289.             for (int position = 0; position < d.size(); position++)  
  290.             {  
  291.                 Feature f = d.getFeatureAt(position);  
  292.                 int w = f.dim;  
  293.                 double n = f.weight;  
  294.   
  295.                 double sum = 0.0;  
  296.                 for (int z = 0; z < this.K; z++)  
  297.                 {  
  298.                     sum += Pz[z] * Pd_z[z][m] * Pw_z[z][w];  
  299.                 }  
  300.                 L += n * Math.log10(sum);  
  301.             }  
  302.         }  
  303.         return L;  
  304.     }  
  305. }  
  306.   
  307. public class PLSA {  
  308.     public static void main(String[] args) {  
  309.           
  310.         ProbabilisticLSA plsa = new ProbabilisticLSA();  
  311.         //the file is not used, the hard coded data is used instead, but file name should be valid,  
  312.         //just replace the name by something valid.  
  313.         plsa.doPLSA("C:\\Users\\APolar\\workspace\\PLSA\\src\\data.txt"260);  
  314.         System.out.println("end PLSA");  
  315.     }  
  316. }  


4 Estimate parameters in a simple mixture unigram language model by EM

在PLSA的参数估计中,我们使用了EM算法。EM算法经常用来估计包含“缺失数据”或者“隐含变量”模型的参数估计问题。这两个概念是互相联系的,当我们的模型中有“隐含变量”时,我们会认为原始数据是“不完全的数据”,因为隐含变量的值无法观察到;反过来,当我们的数据incomplete时,我们可以通过增加隐含变量来对“缺失数据”建模。


为了加深对EM算法的理解,下面我们来看如何用EM算法来估计一个简单混合unigram语言模型的参数。本部分主要参考Zhai老师的EM算法Notes


4.1 最大似然估计与隐含变量引入

所谓unigram语言模型,就是构建语言模型是抛弃所有上下文信息,认为一个词出现的概率与其所在位置无关,具体概率图模型可以参见LDA及Gibbs Samping一文中的介绍。什么是混合模型(mixture model)呢?通俗的说混合概率模型就是由最基本的概率分布比如正态分布、多元分布等经过线性组合形成的新的概率模型,比如混合高斯模型就是由K个高斯分布线性组合而得到。混合模型中产生数据的确切“component model”对我们是隐藏的。我们假设混合模型包含两个multinomial component model,一个是背景词生成模型,另一个是主题词生成模型。注意这种模型组成方式在概率语言模型中很常见。为了表示单词是哪个模型生成的,我们会为每个单词增加一个布尔类型的控制变量。


文档的对数似然函数为



为第i个文档中的第j个词,为表示文档中背景词比例的参数,通常根据经验给定。因此是已知的,我们只需要估计即可。

同样的我们首先试图用最大似然估计来估计参数。也就是去找最大化似然函数的参数值,有




这是一个关于的函数,同样的,包含在了对数和中。因此很难求解极大值,用拉格朗日乘数法,你会发现偏导数等于0得到的方程很难求解。所以我们需要依赖数值算法,而EM算法就是其中常用的一种。


我们为每个单词引入一个布尔类型的变量z表示该单词是background word 还是topic word.即




这里我们假设"complete data"不仅包含可以观察到F中的所有单词,而且还包括隐含的变量z。那么根据EM算法,在E步骤我们计算“complete data”的对数似然函数有




比较一下,求和运算在对数之外进行,因为此时通过控制变量z的设置,我们明确知道了单词是由背景词分布还是topic 词分布产生的。的关系是怎样的呢?如果带估计参数是,原始数据是X,对于每一个原始数据分配了一个隐含变量H,则有




4.2 似然函数的下界分析

EM算法的基本思想就是初始随机给定待估计参数的值,然后通过E步骤和M步骤两步迭代去不断搜索更好的参数值。更好的参数值应该要满足使得似然函数更大。我们假设一个潜在的更好参数值是,第n次迭代M步骤得到的参数估计值是,那么两个参数值对应的似然函数和"complete data"的似然函数的差满足




我们寻找更好参数值的目标就是要最大化,也等价于最大化。我们来计算隐含变量在给定当前数据X和当前估计的参数值条件下的条件概率分布即,有




其中右边第三项是的相对熵,总为非负值。因此我们有




于是我们得到了潜在更好参数值的incomplete data似然函数的下界。这里我们尤其要注意右边后两项为常数,因为不包含。所以incomplete data似然函数的下界就是complete data似然函数的期望,也就是诸多EM算法讲义中出现的Q函数,表达式为




可以看出这个期望等于complete data似然函数乘以对应隐含变量条件概率再求和。对于我们要求解的问题,Q函数就是




这里多解释几句Q函数。单词相应的变量z为0时,单词为topic word,从多元分布中产生;当z为1时,单词为background word,从多元分布产生。同时我们也可以看到如何求Q函数即complete data似然函数的期望,也就是我们要最大化的那个期望(EM算法最大化期望指的就是这个期望),我们要特别关注隐含变量在观察到数据X和前一轮估计出的参数值条件下取不同值的概率,而隐含变量不同的值对应complete data的不同的似然函数,我们要计算的所谓的期望就是指complete data的似然函数值在不同隐含变量取值情况下的期望值。


4.3 EM算法的一般步骤

通过4.2部分的分析,我们知道,如果我们在下一轮迭代中可以找到一个更好的参数值使得




那么相应的也会有,因此EM算法的一般步骤如下

(1) 随机初始化参数值,也可以根据任何关于最佳参数取值范围的先验知识来初始化

(2) 不断两步迭代寻找更优的参数值

     (a) E步骤(求期望) 计算Q函数 




     (b)M步骤(最大化)通过最大化Q函数来寻找更优的参数值




(3) 当似然函数收敛时算法停止。


这里需要注意如何尽量保证EM算法可以找到全局最优解而不是局部最优解呢?第一种方法是尝试许多不同的参数初始值,然后从得到的很多估计出的参数值中选取最优的;第二种方法是通过一个更简单的模型比如只有唯一全局最大值的模型来决定复杂模型的初始值。


通过前面的分析可以知道,EM算法的优势在于complete data的似然函数更容易最大化,因为已经假定了隐含变量的取值,当然要乘以隐含变量取该值的条件概率,所以最终变成了最大化期望值。由于隐含变量变成了已知量,Q函数比原始incomplete data的似然函数更容易求最大值。因此对于“缺失数据”的情况,我们通过引入隐含变量使得complete data的似然函数容易最大化。


在E步骤中,主要的计算难点在于计算隐含变量的条件概率,在PLSA中就是




在我们这个简单混合语言模型的例子中就是




我们假设z的取值只于当前那一个单词有关,计算很容易,但是在LDA中用这种方法计算隐含变量的条件概率和最大化Q函数就比较复杂,可以参见原始LDA论文的参数推导部分。我们也可以用更简单的Gibbs Sampling来估计参数,具体可以参见LDA及Gibbs Samping


继续我们的问题,下面便是M步骤。使用拉格朗日乘数法来求Q函数的最大值,约束条件是




构造拉格朗日辅助函数




对自变量求偏导数




令偏导数为0解出来唯一的极值点




容易知道这里唯一的极值点就是最值点了。注意这里Zhai老师变换了一下变量表示,把对文档里面词的遍历转化成了对词典里面的term的遍历,因为z的取值至于对应的那一个单词有关,与上下文无关。因此E步骤求隐含变量的条件概率公式也相应变成了




最后我们就得到了简单混合Unigram语言模型的EM算法更新公式

即E步骤 求隐含变量条件概率和M步骤 最大化期望估计参数的公式




整个计算过程我们可以看到,我们不需要明确求出Q函数的表达式。取而代之的是我们计算隐含变量的条件概率,然后通过最大化Q函数来得到新的参数估计值。

因此EM算法两步迭代的过程实质是在寻找更好的待估计参数的值使得原始数据即incomplete data似然函数的下界不断提升,而这个“下界“就是引入隐含变量之后的complete data似然函数的期望,也就是诸多EM算法讲义中出现的Q函数,通过最大化Q函数来寻找更优的参数值。同时,上一轮估计出的参数值会在下一轮E步骤中当成已知条件计算隐含变量的条件概率,而这个条件概率又是最大化Q函数求新的参数值是所必需的。


5 Estimate parameters in GMM by EM

经过第3部分和第4部分用EM算法求解PLSA和简单unigram混合模型参数估计问题的详细分析,相信大部分读者已经对EM算法有了一定理解。关于EM算法的材料包括PRML会首先介绍用EM算法去求解混合高斯模型GMM的参数估计问题。下面就让我们来看看如何用EM算法来求解混合高斯模型GMM。


混合高斯模型GMM由K个高斯模型的线性组合组成,高斯模型就是正态分布模型,其中每个高斯模型我们成为一个”Component“,GMM的概率密度函数就是这K个高斯模型概率密度函数的线性组合即



其中

就是高斯分布即正态分布的概率密度函数。这是x为向量的情况,对于x为标量的情况就是


大部分读者应该对标量情形的概率分布更熟悉。这里啰嗦几句,最近看机器学习的论文和书籍,里面的随机变量基本都是多维向量,向量的计算比如加减乘除和求导运算都和标量运算有一些区别,尤其是求导运算,向量和矩阵的求导运算会麻烦很多,看pluskid推荐的一本册子Matrix Cookbook,里面有很多矩阵求导公式,直接查阅应该会更方便。


下面继续说GMM。根据上面给出的概率密度函数,如果我们要从 GMM 的分布中Sample一个样本,实际上可以分为两步:首先随机地在这 K 个 Component 之中选一个,每个 Component 被选中的概率实际上就是它的系数 \pi_k ,选中了 Component 之后,再单独地考虑从这个 Component 的分布中选取一个样本点就可以了。在PRML上,引入了一个K维二值随机变量z,只有1维是1,其他维都是0。唯一那个非零的维对应的就是GMM参数样本时被选中的那个高斯分布,而某一维非零的概率就是\pi_k,即



下面我们开始估计GMM的参数,包括这K个高斯分布的所有均值和方差以及线性组合的系数。我们给每个样本数据增加一个隐含变量, 就是上面所说的K维向量,表明了是从哪个高斯分布中sample出来的。对应的概率图模型就是




观察变量的对数似然函数为




令对的偏导数等于0我们有




注意这里我们定义了表示后验概率,也就是第n个样本是有第k个高斯分布产生的概率。可以解出



就是由第K个高斯分布产生的样本点的总数;用聚类的观点看,就是聚到cluster k的样本点总数。然后我们将对数似然函数对求偏导数,令偏导数为0,得到协方差矩阵




最后我们求系数\pi_k。注意到系数的和为1,即




这就是约束条件,最大化对数似然函数又成为了条件极值问题。我们仍然用拉格朗日乘数法,构造辅助函数如下




求导数,令导数为0有



这样我们就估计出来系数项。

因此用EM算法估计GMM参数的步骤如下


(1) E步骤:估计数据由每个 Component 生成的概率:对于每个数据  来说,它由第 k 个 Component 生成的概率为


注意里面 \mu_k 和 \Sigma_k 也是需要我们估计的值,在E步骤我们假定 \mu_k 和 \Sigma_k 均已知,我们使用上一次迭代所得的值(或者初始值)。


(2)M步骤:由最大估计求出高斯分布的所有均值、方差和线性组合的系数,更新待估计的参数值,根据上面的推导,计算公式是



其中


(3)重复迭代E步骤和M步骤,直到似然函数


收敛时算法停止。

更多关于EM算法的深入分析,可以参考PRML第9章内容。

最后我们给出用EM算法估计GMM参数的Matlab实现,出自pluskid的博客

 

[plain]   view plain copy
  1. function varargout = gmm(X, K_or_centroids)  
  2. % ============================================================  
  3. % Expectation-Maximization iteration implementation of  
  4. % Gaussian Mixture Model.  
  5. %  
  6. % PX = GMM(X, K_OR_CENTROIDS)  
  7. % [PX MODEL] = GMM(X, K_OR_CENTROIDS)  
  8. %  
  9. %  - X: N-by-D data matrix.  
  10. %  - K_OR_CENTROIDS: either K indicating the number of  
  11. %       components or a K-by-D matrix indicating the  
  12. %       choosing of the initial K centroids.  
  13. %  
  14. %  - PX: N-by-K matrix indicating the probability of each  
  15. %       component generating each point.  
  16. %  - MODEL: a structure containing the parameters for a GMM:  
  17. %       MODEL.Miu: a K-by-D matrix.  
  18. %       MODEL.Sigma: a D-by-D-by-K matrix.  
  19. %       MODEL.Pi: a 1-by-K vector.  
  20. % ============================================================  
  21.    
  22.     threshold = 1e-15;  
  23.     [N, D] = size(X);  
  24.    
  25.     if isscalar(K_or_centroids)  
  26.         K = K_or_centroids;  
  27.         % randomly pick centroids  
  28.         rndp = randperm(N);  
  29.         centroids = X(rndp(1:K), :);  
  30.     else  
  31.         K = size(K_or_centroids, 1);  
  32.         centroids = K_or_centroids;  
  33.     end  
  34.    
  35.     % initial values  
  36.     [pMiu pPi pSigma] = init_params();  
  37.    
  38.     Lprev = -inf;  
  39.     while true  
  40.         Px = calc_prob();  
  41.    
  42.         % new value for pGamma  
  43.         pGamma = Px .* repmat(pPi, N, 1);  
  44.         pGamma = pGamma ./ repmat(sum(pGamma, 2), 1, K);  
  45.    
  46.         % new value for parameters of each Component  
  47.         Nk = sum(pGamma, 1);  
  48.         pMiu = diag(1./Nk) * pGamma' * X;  
  49.         pPi = Nk/N;  
  50.         for kk = 1:K  
  51.             Xshift = X-repmat(pMiu(kk, :), N, 1);  
  52.             pSigma(:, :, kk) = (Xshift' * ...  
  53.                 (diag(pGamma(:, kk)) * Xshift)) / Nk(kk);  
  54.         end  
  55.    
  56.         % check for convergence  
  57.         L = sum(log(Px*pPi'));  
  58.         if L-Lprev < threshold  
  59.             break;  
  60.         end  
  61.         Lprev = L;  
  62.     end  
  63.    
  64.     if nargout == 1  
  65.         varargout = {Px};  
  66.     else  
  67.         model = [];  
  68.         model.Miu = pMiu;  
  69.         model.Sigma = pSigma;  
  70.         model.Pi = pPi;  
  71.         varargout = {Px, model};  
  72.     end  
  73.    
  74.     function [pMiu pPi pSigma] = init_params()  
  75.         pMiu = centroids;  
  76.         pPi = zeros(1, K);  
  77.         pSigma = zeros(D, D, K);  
  78.    
  79.         % hard assign x to each centroids  
  80.         distmat = repmat(sum(X.*X, 2), 1, K) + ...  
  81.             repmat(sum(pMiu.*pMiu, 2)', N, 1) - ...  
  82.             2*X*pMiu';  
  83.         [dummy labels] = min(distmat, [], 2);  
  84.    
  85.         for k=1:K  
  86.             Xk = X(labels == k, :);  
  87.             pPi(k) = size(Xk, 1)/N;  
  88.             pSigma(:, :, k) = cov(Xk);  
  89.         end  
  90.     end  
  91.    
  92.     function Px = calc_prob()  
  93.         Px = zeros(N, K);  
  94.         for k = 1:K  
  95.             Xshift = X-repmat(pMiu(k, :), N, 1);  
  96.             inv_pSigma = inv(pSigma(:, :, k));  
  97.             tmp = sum((Xshift*inv_pSigma) .* Xshift, 2);  
  98.             coef = (2*pi)^(-D/2) * sqrt(det(inv_pSigma));  
  99.             Px(:, k) = coef * exp(-0.5*tmp);  
  100.         end  
  101.     end  
  102. end  


 

6 全文总结

本文主要介绍PLSA及EM算法,首先给出LSA(隐性语义分析)的早期方法SVD,然后引入基于概率的PLSA模型,接着我们详细分析了如何用EM算法估计PLSA、混合unigram 语言模型及混合高斯模型的参数过程,并总结了EM算法的一般形式和运用关键点。关于EM算法收敛性的证明可以参考斯坦福机器学习课程CS229 Andrew Ng老师的课程notes和JerryLead的笔记。EM算法在”缺失数据“和包含”隐含变量“的概率模型参数估计问题中非常常用,是机器学习、数据挖掘及NLP研究必须掌握的算法。


 参考文献及推荐Notes

本文主要参考了Hoffman的PLSA论文、Zhai老师的EM Notes以及PRML第9章内容。

[1] Christopher M. Bishop. Pattern Recognition and Machine Learning (Information Science and Statistics). Springer-Verlag New York, Inc., Secaucus, NJ, USA, 2006.
[2] Gregor Heinrich. Parameter estimation for text analysis. Technical report, 2004.

[3] Wayne Xin Zhao, Note for pLSA and LDA, Technical report, 2011.

[4] Freddy Chong Tat Chua. Dimensionality reduction and clustering of text documents.Technical report, 2009.

[5] CX Zhai, A note on the expectation-maximization (em) algorithm 2007

[6] Qiaozhu Mei, A Note on EM Algorithm for Probabilistic Latent Semantic Analysis 2008

[7] pluskid, 漫谈Clustering, Gaussina Mixture Model

[8] Christopher D. ManningPrabhakar Raghavan and Hinrich Schütze, Introduction to Information Retrieval, Cambridge University Press. 2008.

[9] Tomas Hoffman, Unsupervised Learning by Probabilistic Latent Semantic Analysis. 2011

 

****************************************************************************************************************

第二篇 LDA及Gibbs Sampling

[本文PDF版本下载地址 LDA及Gibbs Sampling-yangliuy]

 1 LDA概要      

 LDA是由Blei,Ng, Jordan 2002年发表于JMLR的概率语言模型,应用到文本建模范畴,就是对文本进行“隐性语义分析”(LSA),目的是要以无指导学习的方法从文本中发现隐含的语义维度-即“Topic”或者“Concept”。隐性语义分析的实质是要利用文本中词项(term)的共现特征来发现文本的Topic结构,这种方法不需要任何关于文本的背景知识。文本的隐性语义表示可以对“一词多义”和“一义多词”的语言现象进行建模,这使得搜索引擎系统得到的搜索结果与用户的query在语义层次上match,而不是仅仅只是在词汇层次上出现交集。

 

2 概率基础

2.1 随机生成过程及共轭分布

     要理解LDA首先要理解随机生成过程。用随机生成过程的观点来看,文本是一系列服从一定概率分布的词项的样本集合。最常用的分布就是Multinomial分布,即多项分布,这个分布是二项分布拓展到K维的情况,比如投掷骰子实验,N次实验结果服从K=6的多项分布。相应的,二项分布的先验Beta分布也拓展到K维,称为Dirichlet分布。在概率语言模型中,通常为Multinomial分布选取的先验分布是Dirichlet分布,因为它们是共轭分布,可以带来计算上的方便性。什么是共轭分布呢?在文本语言模型的参数估计-最大似然估计、MAP及贝叶斯估计一文中我们可以看到,当我们为二项分布的参数p选取的先验分布是Beta分布时,以p为参数的二项分布用贝叶斯估计得到的后验概率仍然服从Beta分布,由此我们说二项分布和Beta分布是共轭分布。这就是共轭分布要满足的性质。在LDA中,每个文档中词的Topic分布服从Multinomial分布,其先验选取共轭先验即Dirichlet分布;每个Topic下词的分布服从Multinomial分布,其先验也同样选取共轭先验即Dirichlet分布。


 2.2 Multinomial分布和 Dirichlet分布

    上面从二项分布和Beta分布出发引出了Multinomial分布和Dirichlet分布。这两个分布在概率语言模型中很常用,让我们深入理解这两个分布。Multinomial分布的分布律如下



   多项分布来自N次独立重复实验,每次实验结果可能有K种,式子中为实验结果向量,N为实验次数,为出现每种实验结果的概率组成的向量,这个公式给出了出现所有实验结果的概率计算方法。当K=2时就是二项分布,K=6时就是投掷骰子实验。很好理解,前面的系数其实是枚举实验结果的不同出现顺序,即



后面表示第K种实验结果出现了次,所以是概率的相应次幂再求乘积。但是如果我们不考虑文本中词出现的顺序性,这个系数就是1。 本文后面的部分可以看出这一点。显然有各维之和为1,所有之和为N。

    Dirichlet分布可以看做是“分布之上的分布”,从Dirichlet分布上Draw出来的每个样本就是多项分布的参数向量。其分布律为




    为Dirichlet分布的参数,在概率语言模型中通常会根据经验给定,由于是参数向量服从分布的参数,因此称为“hyperparamer”。是Dirichlet delta函数,可以看做是Beta函数拓展到K维的情况,但是在有的文献中也直接写成。根据Dirichlet分布在上的积分为1(概率的基本性质),我们可以得到一个重要的公式




这个公式在后面LDA的参数Inference中经常使用。下图给出了一个Dirichlet分布的实例



在许多应用场合,我们使用对称Dirichlet分布,其参数是两个标量:维数K和参数向量各维均值. 其分布律如下



关于Dirichlet分布,维基百科上有一张很有意思的图如下

File:LogDirichletDensity-alpha 0.3 to alpha 2.0.gif

个图将Dirichlet分布的概率密度函数取对数

\log (f(x_1,\dots, x_{K-1}; \alpha_1,\dots, \alpha_K)) = \log\left(\frac{1}{\mathrm{B}(\alpha)} \prod_{i=1}^K x_i^{\alpha_i - 1}\right)=  + \sum_{i=1}^K \alpha_i \log(x_i) - \sum_{i=1}^K \log(x_i) - \sum_{i=1}^K  \log(\Gamma(\alpha_i)) + \log(\Gamma(\sum_{i=1}^K \alpha_i))

并且使用对称Dirichlet分布,取K=3,也就是有两个独立参数 x_1, x_2 ,分别对应图中的两个坐标轴,第三个参数始终满足x_3 = 1-x_1-x_2且 \alpha_1=\alpha_2=\alpha_3=\alpha ,图中反映的是\alpha从0.3变化到2.0的概率对数值的变化情况。


3 unigram model

我们先介绍比较简单的unigram model。其概率图模型图示如下




关于概率图模型尤其是贝叶斯网络的介绍可以参见 Stanford概率图模型(Probabilistic Graphical Model)— 第一讲 贝叶斯网络基础一文。简单的说,贝叶斯网络是一个有向无环图,图中的结点是随机变量,图中的有向边代表了随机变量的条件依赖关系。unigram model假设文本中的词服从Multinomial分布,而Multinomial分布的先验分布为Dirichlet分布。图中双线圆圈表示我们在文本中观察到的第n个词,表示文本中一共有N个词。加上方框表示重复,就是说一共有N个这样的随机变量是隐含未知变量,分别是词服从的Multinomial分布的参数和该Multinomial分布的先验Dirichlet分布的参数。一般由经验事先给定,由观察到的文本中出现的词学习得到,表示文本中出现每个词的概率。

 

4 LDA

 理解了unigram model之后,我们来看LDA。我们可以假想有一位大作家,比如莫言,他现在要写m篇文章,一共涉及了K个Topic,每个Topic下的词分布为一个从参数为的Dirichlet先验分布中sample出来的Multinomial分布(注意词典由term构成,每篇文章由word构成,前者不能重复,后者可以重复)。对于每篇文章,他首先会从一个泊松分布中sample一个值作为文章长度,再从一个参数为的Dirichlet先验分布中sample出一个Multinomial分布作为该文章里面出现每个Topic下词的概率;当他想写某篇文章中的第n个词的时候,首先从该文章中出现每个Topic的Multinomial分布中sample一个Topic,然后再在这个Topic对应的词的Multinomial分布中sample一个词作为他要写的词。不断重复这个随机生成过程,直到他把m篇文章全部写完。这就是LDA的一个形象通俗的解释。用数学的语言描述就是如下过程




转化成概率图模型表示就是




图中K为主题个数,M为文档总数,是第m个文档的单词总数。 是每个Topic下词的多项分布的Dirichlet先验参数,   是每个文档下Topic的多项分布的Dirichlet先验参数。是第m个文档中第n个词的主题,是m个文档中的第n个词。剩下来的两个隐含变量分别表示第m个文档下的Topic分布和第k个Topic下词的分布,前者是k维(k为Topic总数)向量,后者是v维向量(v为词典中term总数)。

    给定一个文档集合,是可以观察到的已知变量,是根据经验给定的先验参数,其他的变量都是未知的隐含变量,也是我们需要根据观察到的变量来学习估计的。根据LDA的图模型,我们可以写出所有变量的联合分布




那么一个词初始化为一个term t的概率是




也就是每个文档中出现topic k的概率乘以topic k下出现term t的概率,然后枚举所有topic求和得到。整个文档集合的似然函数就是




5 用Gibbs Sampling学习LDA

5.1   Gibbs Sampling的流程

 从第4部分的分析我们知道,LDA中的变量都是未知的隐含变量,也是我们需要根据观察到的文档集合中的词来学习估计的,那么如何来学习估计呢?这就是概率图模型的Inference问题。主要的算法分为exact inference和approximate inference两类。尽管LDA是最简单的Topic Model, 但是其用exact inference还是很困难的,一般我们采用approximate inference算法来学习LDA中的隐含变量。比如LDA原始论文Blei02中使用的mean-field variational expectation maximisation 算法和Griffiths02中使用的Gibbs Sampling,其中Gibbs Sampling 更为简单易懂。

    Gibbs Sampling 是Markov-Chain Monte Carlo算法的一个特例。这个算法的运行方式是每次选取概率向量的一个维度,给定其他维度的变量值Sample当前维度的值。不断迭代,直到收敛输出待估计的参数。可以图示如下



   初始时随机给文本中的每个单词分配主题,然后统计每个主题z下出现term t的数量以及每个文档m下出现主题z中的词的数量,每一轮计算,即排除当前词的主题分配,根据其他所有词的主题分配估计当前词分配各个主题的概率。当得到当前词属于所有主题z的概率分布后,根据这个概率分布为该词sample一个新的主题。然后用同样的方法不断更新下一个词的主题,直到发现每个文档下Topic分布和每个Topic下词的分布收敛,算法停止,输出待估计的参数,最终每个单词的主题也同时得出。实际应用中会设置最大迭代次数。每一次计算的公式称为Gibbs updating rule.下面我们来推导LDA的联合分布和Gibbs updating rule。


5.2   LDA的联合分布

由LDA的概率图模型,我们可以把LDA的联合分布写成




第一项和第二项因子分别可以写成



可以发现两个因子的展开形式很相似,第一项因子是给定主题Sample词的过程,可以拆分成从Dirichlet先验中SampleTopic Z下词的分布和从参数为的多元分布中Sample词这两个步骤,因此是Dirichlet分布和Multinomial分布的概率密度函数相乘,然后在上积分。注意这里用到的多元分布没有考虑词的顺序性,因此没有前面的系数项。表示term t被观察到分配topic z的次数,表示topic k分配给文档m中的word的次数.此为这里面还用到了2.2部分中导出的一个公式




因此这些积分都可以转化成Dirichlet delta函数,并不需要算积分。第二个因子是给定文档,sample当前词的主题的过程。由此LDA的联合分布就可以转化成全部由Dirichlet delta函数组成的表达式




这个式子在后面推导Gibbs updating rule时需要使用。


5.3   Gibbs updating rule

得到LDA的联合分布后,我们就可以推导Gibbs updating rule,即排除当前词的主题分配,根据其他词的主题分配和观察到的单词来计算当前词主题的概率公式




里面用到了伽马函数的性质


\Gamma(z+1)=z \, \Gamma(z).


同时需要注意到



这一项与当前词的主题分配无关,因为无论分配那个主题,对所有k求和的结果都是一样的,区别只在于拿掉的是哪个主题下的一个词。因此可以当成常量,最后我们只需要得到一个成正比的计算式来作为Gibbs updating rule即可。


5.4 Gibbs sampling algorithm

当Gibbs sampling 收敛后,我们需要根据最后文档集中所有单词的主题分配来计算,作为我们估计出来的概率图模型中的隐含变量。每个文档上Topic的后验分布和每个Topic下的term后验分布如下




可以看出这两个后验分布和对应的先验分布一样,仍然为Dirichlet分布,这也是共轭分布的性质决定的。

使用Dirichlet分布的期望计算公式



我们可以得到两个Multinomial分布的参数的计算公式如下



综上所述,用Gibbs Sampling 学习LDA参数的算法伪代码如下




关于这个算法的代码实现可以参见

* yangliuy's LDAGibbsSampling https://github.com/yangliuy/LDAGibbsSampling

Gregor Heinrich's LDA-J
Yee Whye Teh's Gibbs LDA Matlab codes
Mark Steyvers and Tom Griffiths's topic modeling matlab toolbox
GibbsLDA++


6 参考文献及推荐Notes

本文部分公式及图片来自 Parameter estimation for text analysis,感谢Gregor Heinrich详实细致的Technical report。 看过的一些关于LDA和Gibbs Sampling 的Notes, 这个是最准确细致的,内容最为全面系统。下面几个Notes对Topic Model感兴趣的朋友也推荐看一看。

[1] Christopher M. Bishop. Pattern Recognition and Machine Learning (Information Science and Statistics). Springer-Verlag New York, Inc., Secaucus, NJ, USA, 2006.
[2] Gregor Heinrich. Parameter estimation for text analysis. Technical report, 2004.
[3] Wang Yi. Distributed Gibbs Sampling of Latent Topic Models: The Gritty Details Technical report, 2005.

[4] Wayne Xin Zhao, Note for pLSA and LDA, Technical report, 2011.

[5] Freddy Chong Tat Chua. Dimensionality reduction and clustering of text documents.Technical report, 2009.

[6] Wikipedia, Dirichlet distribution , http://en.wikipedia.org/wiki/Dirichlet_distribution

 

********************************************************************************************************************

 

第三章:LDA Gibbs Sampling的JAVA 实现

在本系列博文的前两篇,我们系统介绍了PLSA, LDA以及它们的参数Inference 方法,重点分析了模型表示和公式推导部分。曾有位学者说,“做研究要顶天立地”,意思是说做研究空有模型和理论还不够,我们还得有扎实的程序code和真实数据的实验结果来作为支撑。本文就重点分析 LDA Gibbs Sampling的JAVA 实现,并给出apply到newsgroup18828新闻文档集上得出的Topic建模结果。

本项目Github地址 https://github.com/yangliuy/LDAGibbsSampling


1、文档集预处理

要用LDA对文本进行topic建模,首先要对文本进行预处理,包括token,去停用词,stem,去noise词,去掉低频词等等。当语料库比较大时,我们也可以不进行stem。然后将文本转换成term的index表示形式,因为后面实现LDA的过程中经常需要在term和index之间进行映射。Documents类的实现如下,里面定义了Document内部类,用于描述文本集合中的文档。

 

[java]   view plain copy
  1. package liuyang.nlp.lda.main;  
  2.   
  3. import java.io.File;  
  4. import java.util.ArrayList;  
  5. import java.util.HashMap;  
  6. import java.util.Map;  
  7. import java.util.regex.Matcher;  
  8. import java.util.regex.Pattern;  
  9.   
  10. import liuyang.nlp.lda.com.FileUtil;  
  11. import liuyang.nlp.lda.com.Stopwords;  
  12.   
  13. /**Class for corpus which consists of M documents 
  14.  * @author yangliu 
  15.  * @blog http://blog.csdn.net/yangliuy 
  16.  * @mail yangliuyx@gmail.com 
  17.  */  
  18.   
  19. public class Documents {  
  20.       
  21.     ArrayList<Document> docs;   
  22.     Map<String, Integer> termToIndexMap;  
  23.     ArrayList<String> indexToTermMap;  
  24.     Map<String,Integer> termCountMap;  
  25.       
  26.     public Documents(){  
  27.         docs = new ArrayList<Document>();  
  28.         termToIndexMap = new HashMap<String, Integer>();  
  29.         indexToTermMap = new ArrayList<String>();  
  30.         termCountMap = new HashMap<String, Integer>();  
  31.     }  
  32.       
  33.     public void readDocs(String docsPath){  
  34.         for(File docFile : new File(docsPath).listFiles()){  
  35.             Document doc = new Document(docFile.getAbsolutePath(), termToIndexMap, indexToTermMap, termCountMap);  
  36.             docs.add(doc);  
  37.         }  
  38.     }  
  39.       
  40.     public static class Document {    
  41.         private String docName;  
  42.         int[] docWords;  
  43.           
  44.         public Document(String docName, Map<String, Integer> termToIndexMap, ArrayList<String> indexToTermMap, Map<String, Integer> termCountMap){  
  45.             this.docName = docName;  
  46.             //Read file and initialize word index array  
  47.             ArrayList<String> docLines = new ArrayList<String>();  
  48.             ArrayList<String> words = new ArrayList<String>();  
  49.             FileUtil.readLines(docName, docLines);  
  50.             for(String line : docLines){  
  51.                 FileUtil.tokenizeAndLowerCase(line, words);  
  52.             }  
  53.             //Remove stop words and noise words  
  54.             for(int i = 0; i < words.size(); i++){  
  55.                 if(Stopwords.isStopword(words.get(i)) || isNoiseWord(words.get(i))){  
  56.                     words.remove(i);  
  57.                     i--;  
  58.                 }  
  59.             }  
  60.             //Transfer word to index  
  61.             this.docWords = new int[words.size()];  
  62.             for(int i = 0; i < words.size(); i++){  
  63.                 String word = words.get(i);  
  64.                 if(!termToIndexMap.containsKey(word)){  
  65.                     int newIndex = termToIndexMap.size();  
  66.                     termToIndexMap.put(word, newIndex);  
  67.                     indexToTermMap.add(word);  
  68.                     termCountMap.put(word, new Integer(1));  
  69.                     docWords[i] = newIndex;  
  70.                 } else {  
  71.                     docWords[i] = termToIndexMap.get(word);  
  72.                     termCountMap.put(word, termCountMap.get(word) + 1);  
  73.                 }  
  74.             }  
  75.             words.clear();  
  76.         }  
  77.           
  78.         public boolean isNoiseWord(String string) {  
  79.             // TODO Auto-generated method stub  
  80.             string = string.toLowerCase().trim();  
  81.             Pattern MY_PATTERN = Pattern.compile(".*[a-zA-Z]+.*");  
  82.             Matcher m = MY_PATTERN.matcher(string);  
  83.             // filter @xxx and URL  
  84.             if(string.matches(".*www\\..*") || string.matches(".*\\.com.*") ||   
  85.                     string.matches(".*http:.*") )  
  86.                 return true;  
  87.             if (!m.matches()) {  
  88.                 return true;  
  89.             } else  
  90.                 return false;  
  91.         }  
  92.           
  93.     }  
  94. }  


2 LDA Gibbs Sampling

文本预处理完毕后我们就可以实现LDA Gibbs Sampling。 首先我们要定义需要的参数,我的实现中在程序中给出了参数默认值,同时也支持配置文件覆盖,程序默认优先选用配置文件的参数设置。整个算法流程包括模型初始化,迭代Inference,不断更新主题和待估计参数,最后输出收敛时的参数估计结果。

包含主函数的配置参数解析类如下:

 

[java]   view plain copy
  1. package liuyang.nlp.lda.main;  
  2.   
  3. import java.io.File;  
  4. import java.io.IOException;  
  5. import java.util.ArrayList;  
  6.   
  7. import liuyang.nlp.lda.com.FileUtil;  
  8. import liuyang.nlp.lda.conf.ConstantConfig;  
  9. import liuyang.nlp.lda.conf.PathConfig;  
  10.   
  11. /**Liu Yang's implementation of Gibbs Sampling of LDA 
  12.  * @author yangliu 
  13.  * @blog http://blog.csdn.net/yangliuy 
  14.  * @mail yangliuyx@gmail.com 
  15.  */  
  16.   
  17. public class LdaGibbsSampling {  
  18.       
  19.     public static class modelparameters {  
  20.         float alpha = 0.5f; //usual value is 50 / K  
  21.         float beta = 0.1f;//usual value is 0.1  
  22.         int topicNum = 100;  
  23.         int iteration = 100;  
  24.         int saveStep = 10;  
  25.         int beginSaveIters = 50;  
  26.     }  
  27.       
  28.     /**Get parameters from configuring file. If the  
  29.      * configuring file has value in it, use the value. 
  30.      * Else the default value in program will be used 
  31.      * @param ldaparameters 
  32.      * @param parameterFile 
  33.      * @return void 
  34.      */  
  35.     private static void getParametersFromFile(modelparameters ldaparameters,  
  36.             String parameterFile) {  
  37.         // TODO Auto-generated method stub  
  38.         ArrayList<String> paramLines = new ArrayList<String>();  
  39.         FileUtil.readLines(parameterFile, paramLines);  
  40.         for(String line : paramLines){  
  41.             String[] lineParts = line.split("\t");  
  42.             switch(parameters.valueOf(lineParts[0])){  
  43.             case alpha:  
  44.                 ldaparameters.alpha = Float.valueOf(lineParts[1]);  
  45.                 break;  
  46.             case beta:  
  47.                 ldaparameters.beta = Float.valueOf(lineParts[1]);  
  48.                 break;  
  49.             case topicNum:  
  50.                 ldaparameters.topicNum = Integer.valueOf(lineParts[1]);  
  51.                 break;  
  52.             case iteration:  
  53.                 ldaparameters.iteration = Integer.valueOf(lineParts[1]);  
  54.                 break;  
  55.             case saveStep:  
  56.                 ldaparameters.saveStep = Integer.valueOf(lineParts[1]);  
  57.                 break;  
  58.             case beginSaveIters:  
  59.                 ldaparameters.beginSaveIters = Integer.valueOf(lineParts[1]);  
  60.                 break;  
  61.             }  
  62.         }  
  63.     }  
  64.       
  65.     public enum parameters{  
  66.         alpha, beta, topicNum, iteration, saveStep, beginSaveIters;  
  67.     }  
  68.       
  69.     /** 
  70.      * @param args 
  71.      * @throws IOException  
  72.      */  
  73.     public static void main(String[] args) throws IOException {  
  74.         // TODO Auto-generated method stub  
  75.         String originalDocsPath = PathConfig.ldaDocsPath;  
  76.         String resultPath = PathConfig.LdaResultsPath;  
  77.         String parameterFile= ConstantConfig.LDAPARAMETERFILE;  
  78.           
  79.         modelparameters ldaparameters = new modelparameters();  
  80.         getParametersFromFile(ldaparameters, parameterFile);  
  81.         Documents docSet = new Documents();  
  82.         docSet.readDocs(originalDocsPath);  
  83.         System.out.println("wordMap size " + docSet.termToIndexMap.size());  
  84.         FileUtil.mkdir(new File(resultPath));  
  85.         LdaModel model = new LdaModel(ldaparameters);  
  86.         System.out.println("1 Initialize the model ...");  
  87.         model.initializeModel(docSet);  
  88.         System.out.println("2 Learning and Saving the model ...");  
  89.         model.inferenceModel(docSet);  
  90.         System.out.println("3 Output the final model ...");  
  91.         model.saveIteratedModel(ldaparameters.iteration, docSet);  
  92.         System.out.println("Done!");  
  93.     }  
  94. }  


LDA 模型实现类如下

 

[java]   view plain copy
  1. package liuyang.nlp.lda.main;  
  2.   
  3. /**Class for Lda model 
  4.  * @author yangliu 
  5.  * @blog http://blog.csdn.net/yangliuy 
  6.  * @mail yangliuyx@gmail.com 
  7.  */  
  8. import java.io.BufferedWriter;  
  9. import java.io.FileWriter;  
  10. import java.io.IOException;  
  11. import java.util.ArrayList;  
  12. import java.util.Collections;  
  13. import java.util.Comparator;  
  14. import java.util.List;  
  15.   
  16. import liuyang.nlp.lda.com.FileUtil;  
  17. import liuyang.nlp.lda.conf.PathConfig;  
  18.   
  19. public class LdaModel {  
  20.       
  21.     int [][] doc;//word index array  
  22.     int V, K, M;//vocabulary size, topic number, document number  
  23.     int [][] z;//topic label array  
  24.     float alpha; //doc-topic dirichlet prior parameter   
  25.     float beta; //topic-word dirichlet prior parameter  
  26.     int [][] nmk;//given document m, count times of topic k. M*K  
  27.     int [][] nkt;//given topic k, count times of term t. K*V  
  28.     int [] nmkSum;//Sum for each row in nmk  
  29.     int [] nktSum;//Sum for each row in nkt  
  30.     double [][] phi;//Parameters for topic-word distribution K*V  
  31.     double [][] theta;//Parameters for doc-topic distribution M*K  
  32.     int iterations;//Times of iterations  
  33.     int saveStep;//The number of iterations between two saving  
  34.     int beginSaveIters;//Begin save model at this iteration  
  35.       
  36.     public LdaModel(LdaGibbsSampling.modelparameters modelparam) {  
  37.         // TODO Auto-generated constructor stub  
  38.         alpha = modelparam.alpha;  
  39.         beta = modelparam.beta;  
  40.         iterations = modelparam.iteration;  
  41.         K = modelparam.topicNum;  
  42.         saveStep = modelparam.saveStep;  
  43.         beginSaveIters = modelparam.beginSaveIters;  
  44.     }  
  45.   
  46.     public void initializeModel(Documents docSet) {  
  47.         // TODO Auto-generated method stub  
  48.         M = docSet.docs.size();  
  49.         V = docSet.termToIndexMap.size();  
  50.         nmk = new int [M][K];  
  51.         nkt = new int[K][V];  
  52.         nmkSum = new int[M];  
  53.         nktSum = new int[K];  
  54.         phi = new double[K][V];  
  55.         theta = new double[M][K];  
  56.           
  57.         //initialize documents index array  
  58.         doc = new int[M][];  
  59.         for(int m = 0; m < M; m++){  
  60.             //Notice the limit of memory  
  61.             int N = docSet.docs.get(m).docWords.length;  
  62.             doc[m] = new int[N];  
  63.             for(int n = 0; n < N; n++){  
  64.                 doc[m][n] = docSet.docs.get(m).docWords[n];  
  65.             }  
  66.         }  
  67.           
  68.         //initialize topic lable z for each word  
  69.         z = new int[M][];  
  70.         for(int m = 0; m < M; m++){  
  71.             int N = docSet.docs.get(m).docWords.length;  
  72.             z[m] = new int[N];  
  73.             for(int n = 0; n < N; n++){  
  74.                 int initTopic = (int)(Math.random() * K);// From 0 to K - 1  
  75.                 z[m][n] = initTopic;  
  76.                 //number of words in doc m assigned to topic initTopic add 1  
  77.                 nmk[m][initTopic]++;  
  78.                 //number of terms doc[m][n] assigned to topic initTopic add 1  
  79.                 nkt[initTopic][doc[m][n]]++;  
  80.                 // total number of words assigned to topic initTopic add 1  
  81.                 nktSum[initTopic]++;  
  82.             }  
  83.              // total number of words in document m is N  
  84.             nmkSum[m] = N;  
  85.         }  
  86.     }  
  87.   
  88.     public void inferenceModel(Documents docSet) throws IOException {  
  89.         // TODO Auto-generated method stub  
  90.         if(iterations < saveStep + beginSaveIters){  
  91.             System.err.println("Error: the number of iterations should be larger than " + (saveStep + beginSaveIters));  
  92.             System.exit(0);  
  93.         }  
  94.         for(int i = 0; i < iterations; i++){  
  95.             System.out.println("Iteration " + i);  
  96.             if((i >= beginSaveIters) && (((i - beginSaveIters) % saveStep) == 0)){  
  97.                 //Saving the model  
  98.                 System.out.println("Saving model at iteration " + i +" ... ");  
  99.                 //Firstly update parameters  
  100.                 updateEstimatedParameters();  
  101.                 //Secondly print model variables  
  102.                 saveIteratedModel(i, docSet);  
  103.             }  
  104.               
  105.             //Use Gibbs Sampling to update z[][]  
  106.             for(int m = 0; m < M; m++){  
  107.                 int N = docSet.docs.get(m).docWords.length;  
  108.                 for(int n = 0; n < N; n++){  
  109.                     // Sample from p(z_i|z_-i, w)  
  110.                     int newTopic = sampleTopicZ(m, n);  
  111.                     z[m][n] = newTopic;  
  112.                 }  
  113.             }  
  114.         }  
  115.     }  
  116.       
  117.     private void updateEstimatedParameters() {  
  118.         // TODO Auto-generated method stub  
  119.         for(int k = 0; k < K; k++){  
  120.             for(int t = 0; t < V; t++){  
  121.                 phi[k][t] = (nkt[k][t] + beta) / (nktSum[k] + V * beta);  
  122.             }  
  123.         }  
  124.           
  125.         for(int m = 0; m < M; m++){  
  126.             for(int k = 0; k < K; k++){  
  127.                 theta[m][k] = (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);  
  128.             }  
  129.         }  
  130.     }  
  131.   
  132.     private int sampleTopicZ(int m, int n) {  
  133.         // TODO Auto-generated method stub  
  134.         // Sample from p(z_i|z_-i, w) using Gibbs upde rule  
  135.           
  136.         //Remove topic label for w_{m,n}  
  137.         int oldTopic = z[m][n];  
  138.         nmk[m][oldTopic]--;  
  139.         nkt[oldTopic][doc[m][n]]--;  
  140.         nmkSum[m]--;  
  141.         nktSum[oldTopic]--;  
  142.           
  143.         //Compute p(z_i = k|z_-i, w)  
  144.         double [] p = new double[K];  
  145.         for(int k = 0; k < K; k++){  
  146.             p[k] = (nkt[k][doc[m][n]] + beta) / (nktSum[k] + V * beta) * (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);  
  147.         }  
  148.           
  149.         //Sample a new topic label for w_{m, n} like roulette  
  150.         //Compute cumulated probability for p  
  151.         for(int k = 1; k < K; k++){  
  152.             p[k] += p[k - 1];  
  153.         }  
  154.         double u = Math.random() * p[K - 1]; //p[] is unnormalised  
  155.         int newTopic;  
  156.         for(newTopic = 0; newTopic < K; newTopic++){  
  157.             if(u < p[newTopic]){  
  158.                 break;  
  159.             }  
  160.         }  
  161.           
  162.         //Add new topic label for w_{m, n}  
  163.         nmk[m][newTopic]++;  
  164.         nkt[newTopic][doc[m][n]]++;  
  165.         nmkSum[m]++;  
  166.         nktSum[newTopic]++;  
  167.         return newTopic;  
  168.     }  
  169.   
  170.     public void saveIteratedModel(int iters, Documents docSet) throws IOException {  
  171.         // TODO Auto-generated method stub  
  172.         //lda.params lda.phi lda.theta lda.tassign lda.twords  
  173.         //lda.params  
  174.         String resPath = PathConfig.LdaResultsPath;  
  175.         String modelName = "lda_" + iters;  
  176.         ArrayList<String> lines = new ArrayList<String>();  
  177.         lines.add("alpha = " + alpha);  
  178.         lines.add("beta = " + beta);  
  179.         lines.add("topicNum = " + K);  
  180.         lines.add("docNum = " + M);  
  181.         lines.add("termNum = " + V);  
  182.         lines.add("iterations = " + iterations);  
  183.         lines.add("saveStep = " + saveStep);  
  184.         lines.add("beginSaveIters = " + beginSaveIters);  
  185.         FileUtil.writeLines(resPath + modelName + ".params", lines);  
  186.           
  187.         //lda.phi K*V  
  188.         BufferedWriter writer = new BufferedWriter(new FileWriter(resPath + modelName + ".phi"));         
  189.         for (int i = 0; i < K; i++){  
  190.             for (int j = 0; j < V; j++){  
  191.                 writer.write(phi[i][j] + "\t");  
  192.             }  
  193.             writer.write("\n");  
  194.         }  
  195.         writer.close();  
  196.           
  197.         //lda.theta M*K  
  198.         writer = new BufferedWriter(new FileWriter(resPath + modelName + ".theta"));  
  199.         for(int i = 0; i < M; i++){  
  200.             for(int j = 0; j < K; j++){  
  201.                 writer.write(theta[i][j] + "\t");  
  202.             }  
  203.             writer.write("\n");  
  204.         }  
  205.         writer.close();  
  206.           
  207.         //lda.tassign  
  208.         writer = new BufferedWriter(new FileWriter(resPath + modelName + ".tassign"));  
  209.         for(int m = 0; m < M; m++){  
  210.             for(int n = 0; n < doc[m].length; n++){  
  211.                 writer.write(doc[m][n] + ":" + z[m][n] + "\t");  
  212.             }  
  213.             writer.write("\n");  
  214.         }  
  215.         writer.close();  
  216.           
  217.         //lda.twords phi[][] K*V  
  218.         writer = new BufferedWriter(new FileWriter(resPath + modelName + ".twords"));  
  219.         int topNum = 20//Find the top 20 topic words in each topic  
  220.         for(int i = 0; i < K; i++){  
  221.             List<Integer> tWordsIndexArray = new ArrayList<Integer>();   
  222.             for(int j = 0; j < V; j++){  
  223.                 tWordsIndexArray.add(new Integer(j));  
  224.             }  
  225.             Collections.sort(tWordsIndexArray, new LdaModel.TwordsComparable(phi[i]));  
  226.             writer.write("topic " + i + "\t:\t");  
  227.             for(int t = 0; t < topNum; t++){  
  228.                 writer.write(docSet.indexToTermMap.get(tWordsIndexArray.get(t)) + " " + phi[i][tWordsIndexArray.get(t)] + "\t");  
  229.             }  
  230.             writer.write("\n");  
  231.         }  
  232.         writer.close();  
  233.     }  
  234.       
  235.     public class TwordsComparable implements Comparator<Integer> {  
  236.           
  237.         public double [] sortProb; // Store probability of each word in topic k  
  238.           
  239.         public TwordsComparable (double[] sortProb){  
  240.             this.sortProb = sortProb;  
  241.         }  
  242.   
  243.         @Override  
  244.         public int compare(Integer o1, Integer o2) {  
  245.             // TODO Auto-generated method stub  
  246.             //Sort topic word index according to the probability of each word in topic k  
  247.             if(sortProb[o1] > sortProb[o2]) return -1;  
  248.             else if(sortProb[o1] < sortProb[o2]) return 1;  
  249.             else return 0;  
  250.         }  
  251.     }  
  252. }  


程序的实现细节可以参考我在程序中给出的注释,如果理解LDA Gibbs Sampling的算法流程,上面的代码很好理解。其实排除输入输出和参数解析的代码,标准LDA 的Gibbs sampling只需要不到200行程序就可以搞定。当然,里面有很多可以考虑优化和变形的地方。

还有com和conf目录下的源文件分别放置常用函数和配置类,完整的JAVA工程见Github https://github.com/yangliuy/LDAGibbsSampling


3 用LDA Gibbs Sampling对Newsgroup 18828文档集进行主题分析

下面我们给出将上面的LDA Gibbs Sampling的实现Apply到Newsgroup 18828文档集进行主题分析的结果。 我实验时用到的数据已经上传到Github中,感兴趣的朋友可以直接从Github中下载工程运行。 我在Newsgroup 18828文档集随机选择了9个目录,每个目录下选择一个文档,将它们放置在data\LdaOriginalDocs目录下,我设定的模型参数如下

 

[plain]   view plain copy
  1. alpha   0.5  
  2. beta    0.1  
  3. topicNum    10  
  4. iteration   100  
  5. saveStep    10  
  6. beginSaveIters  80  


即设定alpha和beta的值为0.5和0.1, Topic数目为10,迭代100次,从第80次开始保存模型结果,每10次保存一次。

经过100次Gibbs Sampling迭代后,程序输出10个Topic下top的topic words以及对应的概率值如下



我们可以看到虽然是unsupervised learning, LDA分析出来的Topic words还是非常make sense的。比如第5个topic是宗教类的,第6个topic是天文类的,第7个topic是计算机类的。程序的输出还包括模型参数.param文件,topic-word分布phi向量.phi文件,doc-topic分布theta向量.theta文件以及每个文档中每个单词分配到的主题label的.tassign文件。感兴趣的朋友可以从Github https://github.com/yangliuy/LDAGibbsSampling 下载完整工程自己换用其他数据集进行主题分析实验。 本程序是初步实现版本,如果大家发现任何问题或者bug欢迎交流,我第一时间在Github修复bug更新版本。


4 参考文献

[1] Christopher M. Bishop. Pattern Recognition and Machine Learning (Information Science and Statistics). Springer-Verlag New York, Inc., Secaucus, NJ, USA, 2006.
[2] Gregor Heinrich. Parameter estimation for text analysis. Technical report, 2004.
[3] Wang Yi. Distributed Gibbs Sampling of Latent Topic Models: The Gritty Details Technical report, 2005.

[4] Wayne Xin Zhao, Note for pLSA and LDA, Technical report, 2011.

[5] Freddy Chong Tat Chua. Dimensionality reduction and clustering of text documents.Technical report, 2009.

[6] Jgibblda, http://jgibblda.sourceforge.net/

[7]David M. Blei, Andrew Y. Ng, and Michael I. Jordan. 2003. Latent dirichlet allocation. J. Mach. Learn. Res. 3 (March 2003), 993-1022.

 

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值