Probabilistic Language Models and Their Variants (5): A Java Implementation of LDA Gibbs Sampling

1. Corpus preprocessing

The Documents class below reads every file in the corpus directory, tokenizes it, filters out stop words and noise words (URL-like tokens and tokens without letters), and maps each term to an integer index, so that each document is represented as an int[] of word indices.

package liuyang.nlp.lda.main;

import java.io.File;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import liuyang.nlp.lda.com.FileUtil;
import liuyang.nlp.lda.com.Stopwords;

/**Class for corpus which consists of M documents
* @author yangliu
* @blog http://blog.csdn.net/yangliuy
* @mail yangliuyx@gmail.com
*/

public class Documents {

ArrayList<Document> docs;
Map<String, Integer> termToIndexMap;
ArrayList<String> indexToTermMap;
Map<String,Integer> termCountMap;

public Documents(){
docs = new ArrayList<Document>();
termToIndexMap = new HashMap<String, Integer>();
indexToTermMap = new ArrayList<String>();
termCountMap = new HashMap<String, Integer>();
}

public void readDocs(String docsPath){
for(File docFile : new File(docsPath).listFiles()){
Document doc = new Document(docFile.getAbsolutePath(), termToIndexMap, indexToTermMap, termCountMap);
docs.add(doc);
}
}

public static class Document {
private String docName;
int[] docWords;

public Document(String docName, Map<String, Integer> termToIndexMap, ArrayList<String> indexToTermMap, Map<String, Integer> termCountMap){
this.docName = docName;
//Read the file and tokenize its lines into lower-case words
ArrayList<String> docLines = new ArrayList<String>();
FileUtil.readLines(docName, docLines);
ArrayList<String> words = new ArrayList<String>();
for(String line : docLines){
FileUtil.tokenizeAndLowerCase(line, words);
}
//Remove stop words and noise words
for(int i = 0; i < words.size(); i++){
if(Stopwords.isStopword(words.get(i)) || isNoiseWord(words.get(i))){
words.remove(i);
i--;
}
}
//Transfer word to index
this.docWords = new int[words.size()];
for(int i = 0; i < words.size(); i++){
String word = words.get(i);
if(!termToIndexMap.containsKey(word)){
int newIndex = termToIndexMap.size();
termToIndexMap.put(word, newIndex);
indexToTermMap.add(word);
termCountMap.put(word, 1);
docWords[i] = newIndex;
} else {
docWords[i] = termToIndexMap.get(word);
termCountMap.put(word, termCountMap.get(word) + 1);
}
}
words.clear();
}

public boolean isNoiseWord(String string) {
string = string.toLowerCase().trim();
//Filter URL-like tokens
if(string.matches(".*www\\..*") || string.matches(".*\\.com.*") ||
string.matches(".*http:.*"))
return true;
//Treat tokens that contain no letter at all as noise
Pattern MY_PATTERN = Pattern.compile(".*[a-zA-Z]+.*");
Matcher m = MY_PATTERN.matcher(string);
return !m.matches();
}

}
}
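
The helper classes FileUtil and Stopwords from the liuyang.nlp.lda.com package are not listed in this post. The sketch below is a minimal stand-in written against the call sites above; the method names and signatures (readLines, writeLines, tokenizeAndLowerCase, mkdir, isStopword) come from the code, but the method bodies and the stop word list are my assumptions, so the real implementation may differ:

package liuyang.nlp.lda.com;

import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;

public class FileUtil {

    //Read a file line by line into the given list
    public static void readLines(String file, ArrayList<String> lines) {
        try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
            String line;
            while ((line = reader.readLine()) != null) {
                lines.add(line);
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    //Write each string in lines as one line of the given file
    public static void writeLines(String file, ArrayList<String> lines) {
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(file))) {
            for (String line : lines) {
                writer.write(line);
                writer.newLine();
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    //Split a line on non-alphanumeric characters and collect lower-cased tokens
    public static void tokenizeAndLowerCase(String line, ArrayList<String> tokens) {
        for (String token : line.split("[^a-zA-Z0-9]+")) {
            if (!token.isEmpty()) {
                tokens.add(token.toLowerCase());
            }
        }
    }

    //Create the directory (and its parents) if it does not exist yet
    public static void mkdir(File dir) {
        if (!dir.exists()) {
            dir.mkdirs();
        }
    }
}

package liuyang.nlp.lda.com;

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

//A tiny illustrative stop word list; use a full English stop word list in practice
public class Stopwords {

    private static final Set<String> STOPWORDS = new HashSet<String>(Arrays.asList(
            "a", "an", "and", "are", "as", "at", "be", "by", "for", "from",
            "in", "is", "it", "of", "on", "or", "that", "the", "to", "with"));

    public static boolean isStopword(String word) {
        return STOPWORDS.contains(word.toLowerCase());
    }
}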


2. LDA Gibbs Sampling

The driver class below loads the model parameters from a configuration file (falling back to the defaults in modelparameters), reads the corpus, and runs the Gibbs sampler.

package liuyang.nlp.lda.main;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;

import liuyang.nlp.lda.com.FileUtil;
import liuyang.nlp.lda.conf.ConstantConfig;
import liuyang.nlp.lda.conf.PathConfig;

/**Liu Yang's implementation of Gibbs Sampling of LDA
* @author yangliu
* @blog http://blog.csdn.net/yangliuy
* @mail yangliuyx@gmail.com
*/

public class LdaGibbsSampling {

public static class modelparameters {
float alpha = 0.5f; //usual value is 50 / K; with the default topicNum = 100 this gives 0.5
float beta = 0.1f;//usual value is 0.1
int topicNum = 100;
int iteration = 100;
int saveStep = 10;
int beginSaveIters = 50;
}

/**Get parameters from the configuration file. If the
* file specifies a value for a parameter, use it;
* otherwise the default value in the program is kept.
* @param ldaparameters
* @param parameterFile
*/
private static void getParametersFromFile(modelparameters ldaparameters,
String parameterFile) {
ArrayList<String> paramLines = new ArrayList<String>();
FileUtil.readLines(parameterFile, paramLines);
for(String line : paramLines){
String[] lineParts = line.split("\t");
switch(parameters.valueOf(lineParts[0])){
case alpha:
ldaparameters.alpha = Float.valueOf(lineParts[1]);
break;
case beta:
ldaparameters.beta = Float.valueOf(lineParts[1]);
break;
case topicNum:
ldaparameters.topicNum = Integer.valueOf(lineParts[1]);
break;
case iteration:
ldaparameters.iteration = Integer.valueOf(lineParts[1]);
break;
case saveStep:
ldaparameters.saveStep = Integer.valueOf(lineParts[1]);
break;
case beginSaveIters:
ldaparameters.beginSaveIters = Integer.valueOf(lineParts[1]);
break;
}
}
}

public enum parameters{
alpha, beta, topicNum, iteration, saveStep, beginSaveIters;
}

/**
* @param args
* @throws IOException
*/
public static void main(String[] args) throws IOException {
String originalDocsPath = PathConfig.ldaDocsPath;
String resultPath = PathConfig.LdaResultsPath;
String parameterFile = ConstantConfig.LDAPARAMETERFILE;

modelparameters ldaparameters = new modelparameters();
getParametersFromFile(ldaparameters, parameterFile);
Documents docSet = new Documents();
docSet.readDocs(originalDocsPath);
System.out.println("wordMap size " + docSet.termToIndexMap.size());
FileUtil.mkdir(new File(resultPath));
LdaModel model = new LdaModel(ldaparameters);
System.out.println("1 Initialize the model ...");
model.initializeModel(docSet);
System.out.println("2 Learning and Saving the model ...");
model.inferenceModel(docSet);
System.out.println("3 Output the final model ...");
model.saveIteratedModel(ldaparameters.iteration, docSet);
System.out.println("Done!");
}
}
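
For completeness, the two configuration classes referenced in main() could be as simple as the sketch below. The field names are taken from the call sites (ldaDocsPath, LdaResultsPath, LDAPARAMETERFILE), while the concrete paths are placeholders you should adapt to your own layout:

package liuyang.nlp.lda.conf;

public class PathConfig {
    //Directory that holds the raw corpus files (placeholder path)
    public static String ldaDocsPath = "data/LdaOriginalDocs/";
    //Directory where the lda_* result files are written (placeholder path)
    public static String LdaResultsPath = "data/LdaResults/";
}

package liuyang.nlp.lda.conf;

public class ConstantConfig {
    //Tab-separated parameter file, see Section 3 for an example (placeholder path)
    public static final String LDAPARAMETERFILE = "config/LdaParameterFile.txt";
}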


The LDA model implementation class is as follows:

package liuyang.nlp.lda.main;

/**Class for Lda model
* @author yangliu
* @blog http://blog.csdn.net/yangliuy
* @mail yangliuyx@gmail.com
*/
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

import liuyang.nlp.lda.com.FileUtil;
import liuyang.nlp.lda.conf.PathConfig;

public class LdaModel {

int [][] doc;//word index array
int V, K, M;//vocabulary size, topic number, document number
int [][] z;//topic label array
float alpha; //doc-topic dirichlet prior parameter
float beta; //topic-word dirichlet prior parameter
int [][] nmk;//given document m, count times of topic k. M*K
int [][] nkt;//given topic k, count times of term t. K*V
int [] nmkSum;//Sum for each row in nmk
int [] nktSum;//Sum for each row in nkt
double [][] phi;//Parameters for topic-word distribution K*V
double [][] theta;//Parameters for doc-topic distribution M*K
int iterations;//Times of iterations
int saveStep;//The number of iterations between two saves
int beginSaveIters;//Begin save model at this iteration

public LdaModel(LdaGibbsSampling.modelparameters modelparam) {
alpha = modelparam.alpha;
beta = modelparam.beta;
iterations = modelparam.iteration;
K = modelparam.topicNum;
saveStep = modelparam.saveStep;
beginSaveIters = modelparam.beginSaveIters;
}

public void initializeModel(Documents docSet) {
M = docSet.docs.size();
V = docSet.termToIndexMap.size();
nmk = new int [M][K];
nkt = new int[K][V];
nmkSum = new int[M];
nktSum = new int[K];
phi = new double[K][V];
theta = new double[M][K];

//initialize documents index array
doc = new int[M][];
for(int m = 0; m < M; m++){
//Be careful with memory usage here for large corpora
int N = docSet.docs.get(m).docWords.length;
doc[m] = new int[N];
for(int n = 0; n < N; n++){
doc[m][n] = docSet.docs.get(m).docWords[n];
}
}

//initialize topic label z for each word
z = new int[M][];
for(int m = 0; m < M; m++){
int N = docSet.docs.get(m).docWords.length;
z[m] = new int[N];
for(int n = 0; n < N; n++){
int initTopic = (int)(Math.random() * K);// From 0 to K - 1
z[m][n] = initTopic;
//number of words in doc m assigned to topic initTopic add 1
nmk[m][initTopic]++;
//number of terms doc[m][n] assigned to topic initTopic add 1
nkt[initTopic][doc[m][n]]++;
// total number of words assigned to topic initTopic add 1
nktSum[initTopic]++;
}
// total number of words in document m is N
nmkSum[m] = N;
}
}

public void inferenceModel(Documents docSet) throws IOException {
if(iterations < saveStep + beginSaveIters){
System.err.println("Error: the number of iterations should be at least " + (saveStep + beginSaveIters));
System.exit(1);
}
for(int i = 0; i < iterations; i++){
System.out.println("Iteration " + i);
if((i >= beginSaveIters) && (((i - beginSaveIters) % saveStep) == 0)){
//Saving the model
System.out.println("Saving model at iteration " + i +" ... ");
//Firstly update parameters
updateEstimatedParameters();
//Secondly print model variables
saveIteratedModel(i, docSet);
}

//Use Gibbs Sampling to update z[][]
for(int m = 0; m < M; m++){
int N = docSet.docs.get(m).docWords.length;
for(int n = 0; n < N; n++){
// Sample from p(z_i|z_-i, w)
int newTopic = sampleTopicZ(m, n);
z[m][n] = newTopic;
}
}
}
//Update parameters once more so that the final save from main() reflects the last iteration
updateEstimatedParameters();
}

private void updateEstimatedParameters() {
for(int k = 0; k < K; k++){
for(int t = 0; t < V; t++){
phi[k][t] = (nkt[k][t] + beta) / (nktSum[k] + V * beta);
}
}

for(int m = 0; m < M; m++){
for(int k = 0; k < K; k++){
theta[m][k] = (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
}
}
}

private int sampleTopicZ(int m, int n) {
// Sample from the full conditional p(z_i|z_-i, w) using the Gibbs update rule

//Remove topic label for w_{m,n}
int oldTopic = z[m][n];
nmk[m][oldTopic]--;
nkt[oldTopic][doc[m][n]]--;
nmkSum[m]--;
nktSum[oldTopic]--;

//Compute p(z_i = k|z_-i, w)
double [] p = new double[K];
for(int k = 0; k < K; k++){
p[k] = (nkt[k][doc[m][n]] + beta) / (nktSum[k] + V * beta) * (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
}

//Sample a new topic label for w_{m, n} via roulette-wheel selection
//Turn p into cumulative probabilities in place
for(int k = 1; k < K; k++){
p[k] += p[k - 1];
}
double u = Math.random() * p[K - 1]; //p[] is unnormalised
int newTopic;
for(newTopic = 0; newTopic < K; newTopic++){
if(u < p[newTopic]){
break;
}
}

//Add new topic label for w_{m, n}
nmk[m][newTopic]++;
nkt[newTopic][doc[m][n]]++;
nmkSum[m]++;
nktSum[newTopic]++;
return newTopic;
}

public void saveIteratedModel(int iters, Documents docSet) throws IOException {
//lda.params lda.phi lda.theta lda.tassign lda.twords
//lda.params
String resPath = PathConfig.LdaResultsPath;
String modelName = "lda_" + iters;
ArrayList<String> lines = new ArrayList<String>();
lines.add("alpha = " + alpha);
lines.add("beta = " + beta);
lines.add("topicNum = " + K);
lines.add("docNum = " + M);
lines.add("termNum = " + V);
lines.add("iterations = " + iterations);
lines.add("saveStep = " + saveStep);
lines.add("beginSaveIters = " + beginSaveIters);
FileUtil.writeLines(resPath + modelName + ".params", lines);

//lda.phi K*V
BufferedWriter writer = new BufferedWriter(new FileWriter(resPath + modelName + ".phi"));
for (int i = 0; i < K; i++){
for (int j = 0; j < V; j++){
writer.write(phi[i][j] + "\t");
}
writer.write("\n");
}
writer.close();

//lda.theta M*K
writer = new BufferedWriter(new FileWriter(resPath + modelName + ".theta"));
for(int i = 0; i < M; i++){
for(int j = 0; j < K; j++){
writer.write(theta[i][j] + "\t");
}
writer.write("\n");
}
writer.close();

//lda.tassign
writer = new BufferedWriter(new FileWriter(resPath + modelName + ".tassign"));
for(int m = 0; m < M; m++){
for(int n = 0; n < doc[m].length; n++){
writer.write(doc[m][n] + ":" + z[m][n] + "\t");
}
writer.write("\n");
}
writer.close();

//lda.twords phi[][] K*V
writer = new BufferedWriter(new FileWriter(resPath + modelName + ".twords"));
int topNum = 20; //Find the top 20 topic words in each topic
for(int i = 0; i < K; i++){
List<Integer> tWordsIndexArray = new ArrayList<Integer>();
for(int j = 0; j < V; j++){
tWordsIndexArray.add(j);
}
Collections.sort(tWordsIndexArray, new LdaModel.TwordsComparable(phi[i]));
writer.write("topic " + i + "\t:\t");
for(int t = 0; t < Math.min(topNum, V); t++){
writer.write(docSet.indexToTermMap.get(tWordsIndexArray.get(t)) + " " + phi[i][tWordsIndexArray.get(t)] + "\t");
}
writer.write("\n");
}
writer.close();
}

public class TwordsComparable implements Comparator<Integer> {

public double [] sortProb; // Store probability of each word in topic k

public TwordsComparable (double[] sortProb){
this.sortProb = sortProb;
}

@Override
public int compare(Integer o1, Integer o2) {
//Sort word indices by descending probability in topic k
return Double.compare(sortProb[o2], sortProb[o1]);
}
}
}
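
Both sampleTopicZ and updateEstimatedParameters implement the standard collapsed Gibbs sampling equations for LDA; see Heinrich [2] for the full derivation. sampleTopicZ draws each topic assignment from the full conditional

$$p(z_i = k \mid \vec{z}_{\neg i}, \vec{w}) \propto \frac{n_{k,t}^{\neg i} + \beta}{n_k^{\neg i} + V\beta} \cdot \frac{n_{m,k}^{\neg i} + \alpha}{n_m^{\neg i} + K\alpha}$$

where $n_{k,t}$ (nkt) counts how often term $t$ is assigned to topic $k$, $n_k$ (nktSum) is its row sum, $n_{m,k}$ (nmk) counts the words of document $m$ assigned to topic $k$, and $n_m$ (nmkSum) is the length of document $m$. The superscript $\neg i$ means the current assignment of word $i$ is excluded, which is exactly why the four counters are decremented before $p[k]$ is computed and incremented again once the new topic has been drawn. updateEstimatedParameters then reads off the corresponding point estimates

$$\varphi_{k,t} = \frac{n_{k,t} + \beta}{n_k + V\beta}, \qquad \theta_{m,k} = \frac{n_{m,k} + \alpha}{n_m + K\alpha}.$$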


3. Topic analysis of the Newsgroup 18828 corpus with LDA Gibbs Sampling

The experiment uses the following parameter file; each line is a tab-separated name/value pair, matching the line.split("\t") in getParametersFromFile:

alpha	0.5
beta	0.1
topicNum	10
iteration	100
saveStep	10
beginSaveIters	80
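
With these settings the sampler runs 100 iterations and saves the model at iterations 80 and 90 inside inferenceModel (beginSaveIters = 80, saveStep = 10), plus a final save at iteration 100 triggered from main(). Each save writes lda_<iter>.params, .phi, .theta, .tassign and .twords files under PathConfig.LdaResultsPath; the .twords file, which lists the 20 highest-probability words of each topic, is usually the quickest way to judge topic quality.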

4. References

[1] Christopher M. Bishop. Pattern Recognition and Machine Learning (Information Science and Statistics). Springer-Verlag New York, Inc., Secaucus, NJ, USA, 2006.
[2] Gregor Heinrich. Parameter Estimation for Text Analysis. Technical report, 2004.
[3] Wang Yi. Distributed Gibbs Sampling of Latent Topic Models: The Gritty Details. Technical report, 2005.
[4] Wayne Xin Zhao. Note for pLSA and LDA. Technical report, 2011.
[5] Freddy Chong Tat Chua. Dimensionality Reduction and Clustering of Text Documents. Technical report, 2009.
[6] JGibbLDA, http://jgibblda.sourceforge.net/
[7] David M. Blei, Andrew Y. Ng, and Michael I. Jordan. Latent Dirichlet Allocation. Journal of Machine Learning Research 3 (March 2003), 993-1022.
