For an introduction to LDA, see the previous articles; this post is an implementation of LDA inference via Gibbs sampling.
Once the sampler converges, the topic results stay essentially unchanged from one saved snapshot to the next.
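As a quick reference, these are the formulas the code below implements, written to match its count arrays nmk, nkt, nmkSum, nktSum. The Gibbs updating rule sampled in sampleTopicZ and the parameter estimates computed in updateEstimatedParameters are:

p(z_i = k \mid z_{-i}, w) \;\propto\; \frac{n_{k,t}^{-i} + \beta}{n_{k}^{-i} + V\beta} \cdot \frac{n_{m,k}^{-i} + \alpha}{n_{m}^{-i} + K\alpha}

\varphi_{k,t} = \frac{n_{k,t} + \beta}{n_{k} + V\beta}, \qquad \theta_{m,k} = \frac{n_{m,k} + \alpha}{n_{m} + K\alpha}

Here n_{k,t} is the number of times term t is assigned to topic k, n_{m,k} is the number of words in document m assigned to topic k, and the superscript -i means the current word's own assignment is excluded from the counts.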
package org.jazywoo.lda;
import java.util.List;
public class Document {
private String docName;
private List<Integer> words; // term IDs of the words in this document
public Document(String docName) {
this.docName=docName;
}
public String getDocName() {
return docName;
}
public void setDocName(String docName) {
this.docName = docName;
}
public List<Integer> getWords() {
return words;
}
public void setWords(List<Integer> words) {
this.words = words;
}
}
package org.jazywoo.lda;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.jazywoo.tokenization.Tokenization;
import ICTCLAS.I3S.AC.ICTCLAS50;
public class Corpus {
private List<Document> docs; // documents in the corpus
private Map<String, Integer> termIndexMap;// term -> index
private List<String> terms; // index -> term
private Map<String, Integer> termCountMap;// term -> frequency
public Corpus() {
docs = new ArrayList<Document>();
termIndexMap = new HashMap<String, Integer>();
terms = new ArrayList<String>();
termCountMap = new HashMap<String, Integer>();
}
public void loadData(String path) throws IOException{
File folder=new File(path);
if(folder.exists()){
File[] files=folder.listFiles();
for(File f:files){
BufferedReader br = new BufferedReader(new FileReader(f));
String line = "";
StringBuffer buf=new StringBuffer();
while ((line = br.readLine()) != null) {
buf.append(line+" ");
}
br.close(); // close the reader to avoid leaking file handles
addDocument(f.getName(), buf.toString()); // use the file name so documents are distinguishable
}
}
}
private void addDocument(String docName, String content){
Document document=new Document(docName);
String[] words=getWordsFromSentence(content);
List<Integer> wordsList=new ArrayList<Integer>();
for(int i=0;i<words.length;++i){
String term=words[i];
if(termIndexMap.containsKey(term)){
termCountMap.put(term, termCountMap.get(term)+1);
}else{// first time this term is seen
int index=termIndexMap.size();
termIndexMap.put(term, index);
terms.add(term);
termCountMap.put(term, 1); // the first occurrence counts as 1, not 0
}
int termID=termIndexMap.get(term);
wordsList.add(termID);
}
document.setWords(wordsList);
docs.add(document);
}
/** Tokenize a sentence into words, filtering out stop words and noise.
 * @param content the raw text of one document
 * @return the tokens, or an empty array if tokenization fails
 */
private String[] getWordsFromSentence(String content){
ICTCLAS50 ictclas=new ICTCLAS50();
Tokenization tokenization=new Tokenization(ictclas);
boolean isOK=tokenization.init();
String[] words=null;
if(isOK){
try {
words=tokenization.getPartedWordsWithoutSimbol(content);
} catch (UnsupportedEncodingException e) {
e.printStackTrace();
}
tokenization.finish();
}
return words == null ? new String[0] : words; // guard against init/encoding failure so addDocument never sees null
}
public List<Document> getDocs() {
return docs;
}
public void setDocs(List<Document> docs) {
this.docs = docs;
}
public Map<String, Integer> getTermIndexMap() {
return termIndexMap;
}
public void setTermIndexMap(Map<String, Integer> termIndexMap) {
this.termIndexMap = termIndexMap;
}
public List<String> getTerms() {
return terms;
}
public void setTerms(List<String> terms) {
this.terms = terms;
}
}
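Corpus is tied to the ICTCLAS Chinese segmenter through the Tokenization wrapper. If ICTCLAS is not available, or the input is already space-delimited text, a minimal whitespace tokenizer can stand in for getWordsFromSentence. The sketch below is hypothetical and not part of the original code; the stop-word set is a placeholder to be replaced with a real list:

// Hypothetical drop-in replacement for getWordsFromSentence inside Corpus.
// Splits on whitespace, lower-cases tokens, and drops a placeholder stop-word set.
private static final java.util.Set<String> STOP_WORDS =
new java.util.HashSet<String>(java.util.Arrays.asList("the", "a", "of"));
private String[] getWordsFromSentenceSimple(String content) {
List<String> kept = new ArrayList<String>();
for (String token : content.trim().split("\\s+")) {
String term = token.toLowerCase();
if (!term.isEmpty() && !STOP_WORDS.contains(term)) {
kept.add(term);
}
}
return kept.toArray(new String[0]);
}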
package org.jazywoo.lda;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
/**Gibbs Sampling LDA
* @author jazywoo
*
*/
public class LdaModel {
private Corpus docSet;// the corpus being modeled
private int[][] doc;// word index array: doc[m][n] is the vocabulary index (in terms) of the n-th word of document m
private int V, K, M;// vocabulary size, topic number, document number
private int[][] z;// topic label array: z[m][n] is the topic assigned to the n-th word of document m
private float alpha; // doc-topic dirichlet prior parameter
private float beta; // topic-word dirichlet prior parameter
private int[][] nmk;// M*K: nmk[m][k] = number of words in document m assigned to topic k
private int[][] nkt;// K*V: nkt[k][t] = number of times term t is assigned to topic k
private int[] nmkSum;// row sums of nmk: nmkSum[m] = number of words in document m
private int[] nktSum;// row sums of nkt: nktSum[k] = number of words assigned to topic k
// The two latent variables theta and phi are the topic distribution of document m
// and the word distribution of topic k: theta[m] is a K-dimensional vector
// (K = number of topics), phi[k] is a V-dimensional vector (V = vocabulary size).
private double[][] theta;// Parameters for doc-topic distribution M*K
private double[][] phi;// Parameters for topic-word distribution K*V
private int iterations;// Times of iterations
private int saveStep;// The number of iterations between two saving
private int beginSaveIters;// Begin save model at this iteration
public LdaModel(LdaModel.ModelParameter parameter) {
alpha = parameter.alpha;
beta = parameter.beta;
iterations = parameter.iteration;
K = parameter.topicNum;
saveStep = parameter.saveStep;
beginSaveIters = parameter.beginSaveIters;
}
public void initModel(Corpus docSet1) {
this.docSet=docSet1;
M = docSet.getDocs().size();
V = docSet.getTerms().size();
nmk = new int [M][K];
nkt = new int[K][V];
nmkSum = new int[M];
nktSum = new int[K];
phi = new double[K][V];
theta = new double[M][K];
//initialize the document word-index array: doc[m][n] is the index of the
//n-th word of document m in the vocabulary list terms
doc = new int[M][];
for(int m = 0; m < M; m++){
//Notice the limit of memory
int N = docSet.getDocs().get(m).getWords().size();
doc[m] = new int[N];
for(int n = 0; n < N; n++){
doc[m][n] = docSet.getDocs().get(m).getWords().get(n);
}
}
//initialize the topic label z[m][n] for each word
z = new int[M][];
for(int m = 0; m < M; m++){
int N = docSet.getDocs().get(m).getWords().size();
z[m] = new int[N];
for(int n = 0; n < N; n++){
//initially assign a random topic z[m][n]_old to each word
int initTopic = (int)(Math.random() * K);// From 0 to K - 1
z[m][n] = initTopic;
//number of words in doc m assigned to topic initTopic add 1
nmk[m][initTopic]++;
//number of terms doc[m][n] assigned to topic initTopic add 1
nkt[initTopic][doc[m][n]]++;
// total number of words assigned to topic initTopic add 1
nktSum[initTopic]++;
}
// total number of words in document m is N
nmkSum[m] = N;
}
}
public void inferenceModel() throws IOException {
if(iterations < saveStep + beginSaveIters){
System.err.println("Error: the number of iterations must be at least " + (saveStep + beginSaveIters));
System.exit(1);
}
for(int i = 0; i < iterations; i++){
System.out.println("Iteration " + i);
if((i >= beginSaveIters) && (((i - beginSaveIters) % saveStep) == 0)){
//Saving the model
System.out.println("Saving model at iteration " + i +" ... ");
//Firstly update parameters
updateEstimatedParameters();
//Secondly print model variables
saveIteratedModel(i);
}
// z[m][n]: the topic label of each word of each document
//Use Gibbs Sampling to update z[][]
for(int m = 0; m < M; m++){
int N = docSet.getDocs().get(m).getWords().size();
for(int n = 0; n < N; n++){
// Sample from p(z_i|z_-i, w)
int newTopic = sampleTopicZ(m, n);
z[m][n] = newTopic;
}
}
}
// refresh theta and phi from the final counts, so a subsequent
// saveIteratedModel call (as in LDATest below) does not write stale parameters
updateEstimatedParameters();
}
/**
 * Each word was initially assigned a random topic z[m][n]_old (done in initModel),
 * and from those assignments we keep counts of how often each term t occurs under
 * topic z and how many words of document m belong to topic z. In every sweep we
 * compute p(z_i|z_-i, d, w): exclude the current word's own assignment and
 * estimate, from all the other words' assignments, the probability that the
 * current word belongs to each topic. Given this distribution over topics we
 * sample a new topic z[m][n]_new for the word, then update the next word in the
 * same way. When the per-document topic distributions and per-topic word
 * distributions converge, the algorithm stops and outputs the estimated theta and
 * phi, and the final topic of every word comes out at the same time. In practice
 * a maximum number of iterations is set. The formula for p(z_i|z_-i, d, w) is
 * known as the Gibbs updating rule.
 * @param m document index
 * @param n word position in document m
 * @return the newly sampled topic for w_{m,n}
 */
private int sampleTopicZ(int m, int n) {
// Sample from p(z_i|z_-i, w) using the Gibbs updating rule
//Remove the current topic label of w_{m,n} from all counts
int oldTopic = z[m][n];
nmk[m][oldTopic]--;
nkt[oldTopic][doc[m][n]]--;
nmkSum[m]--;
nktSum[oldTopic]--;
//Compute p(z_i = k|z_-i, d, w)
//probability of the current word of the current document belonging to each topic
double [] p = new double[K];
for(int k = 0; k < K; k++){
//(count of term doc[m][n] under topic k + beta) / (words assigned to topic k + V*beta)
//* (count of topic k in document m + alpha) / (words in document m + K*alpha)
//Gibbs sampling: P(z|w,alpha,beta) = P(w,z|alpha,beta) / P(w|alpha,beta)
p[k] = (nkt[k][doc[m][n]] + beta) / (nktSum[k] + V * beta)
* (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
//p[k]=phi[k][doc[m][n]]*theta[m][k];
}
//Sample a new topic label for w_{m, n} by roulette-wheel selection
//Compute the cumulative probabilities of p
for(int k = 1; k < K; k++){
p[k] += p[k - 1];
}
double u = Math.random() * p[K - 1]; //p[] is unnormalised
int newTopic;
for(newTopic = 0; newTopic < K; newTopic++){
if(u < p[newTopic]){
break;
}
}
//Add new topic label for w_{m, n}
nmk[m][newTopic]++;
nkt[newTopic][doc[m][n]]++;
nmkSum[m]++;
nktSum[newTopic]++;
return newTopic;
}
/** Estimate the doc-topic parameters theta and the topic-word parameters phi.
 * theta[m][k] is the topic distribution of document m: p(z_k|d_m) = p(z_k,d_m)/p(d_m)
 * phi[k][t] is the word distribution of topic k: p(w_t|z_k)
 */
private void updateEstimatedParameters() {
for(int k = 0; k < K; k++){
for(int t = 0; t < V; t++){
phi[k][t] = (nkt[k][t] + beta) / (nktSum[k] + V * beta);
//(count of term t under topic k + beta) / (words assigned to topic k + V*beta)
}
}
for(int m = 0; m < M; m++){
for(int k = 0; k < K; k++){
theta[m][k] = (nmk[m][k] + alpha) / (nmkSum[m] + K * alpha);
//(words of document m assigned to topic k + alpha) / (words in document m + K*alpha)
}
}
}
/** Save the model state of one iteration to disk.
 * @param iters iteration number, used in the output file name
 * @throws IOException
 */
public void saveIteratedModel(int iters) throws IOException {
// lda.params lda.phi lda.theta lda.tassign lda.twords
// lda.params
String resPath = "D:\\result\\";
String modelName = "lda_" + iters;
StringBuffer buf=new StringBuffer();
buf.append("alpha = " + alpha + "\n");
buf.append("beta = " + beta + "\n");
buf.append("topicNum = " + K + "\n");
buf.append("docNum = " + M + "\n");
buf.append("termNum = " + V + "\n");
buf.append("iterations = " + iterations + "\n");
buf.append("saveStep = " + saveStep + "\n");
buf.append("beginSaveIters = " + beginSaveIters + "\n");
BufferedWriter writer;
// writer = new BufferedWriter(new FileWriter(resPath
// + modelName + ".params.txt"));
// writer.write(buf.toString());
// writer.close();
//
// // theta and phi: the topic distribution of each document and the word distribution of each topic
// // lda.phi K*V
// writer = new BufferedWriter(new FileWriter(resPath
// + modelName + ".phi.txt"));
// for (int i = 0; i < K; i++) {
// for (int j = 0; j < V; j++) {
// writer.write("topic-word="+phi[i][j] + "\t");
// }
// writer.write("\n");
// }
// writer.close();
// // lda.theta M*K
// writer = new BufferedWriter(new FileWriter(resPath + modelName
// + ".theta.txt"));
// for (int i = 0; i < M; i++) {
// for (int j = 0; j < K; j++) {
// writer.write("doc-topic="+theta[i][j] + "\t");
// }
// writer.write("\n");
// }
// writer.close();
//
// // doc[m][n]: vocabulary index of each word of each document
// // z[m][n]: topic index assigned to each word of each document
// writer = new BufferedWriter(new FileWriter(resPath + modelName
// + ".wordIndex2topicIndex.txt"));
// for (int m = 0; m < M; m++) {
// for (int n = 0; n < doc[m].length; n++) {
// writer.write("doc[m][word]_index="+doc[m][n] + ":" +"z[m][word]_topicIndex="+ z[m][n] + "\t");
// }
// writer.write("\n");
// }
// writer.close();
// lda.twords phi[][] K*V
// for each topic, the 20 words with the highest phi[i], i.e. highest probability
writer = new BufferedWriter(new FileWriter(resPath + modelName
+ ".topic_words.txt"));
int topNum = 20; // Find the top 20 topic words in each topic
for (int i = 0; i < K; i++) {
List<Integer> tWordsIndexArray = new ArrayList<Integer>();//word indices for topic i
for (int j = 0; j < V; j++) {
tWordsIndexArray.add(Integer.valueOf(j));
}
Collections.sort(tWordsIndexArray,
new LdaModel.ArrayDoubleComparator(phi[i]));// sort word indices by phi[i], most probable first
writer.write("topic " + i + ":\t");
for (int t = 0; t < topNum; t++) {
// writer.write(docSet.getTerms().get(tWordsIndexArray.get(t))
// + " " + phi[i][tWordsIndexArray.get(t)] + " ;\t");
writer.write(docSet.getTerms().get(tWordsIndexArray.get(t))+" ");
}
writer.write("\n");
}
writer.close();
}
/**
 * @author jazywoo
 * Comparator used to sort the word indices of a topic by descending probability phi[i]
 */
public class ArrayDoubleComparator implements Comparator<Integer> {
private double[] sortProb; // Store probability of each word in topic k
public ArrayDoubleComparator(double[] sortProb) {
this.sortProb = sortProb;
}
@Override
public int compare(Integer o1, Integer o2) {// sort word indices by the probability of each word in topic k
if (sortProb[o1] > sortProb[o2])
return -1;
else if (sortProb[o1] < sortProb[o2])
return 1;
else
return 0;
}
}
public static class ModelParameter{
public float alpha = 0.5f; //a common heuristic is 50 / K (5.0 for K = 10); a smaller value is used here
public float beta = 0.1f;//usual value is 0.1
public int topicNum = 10;
public int iteration = 100;
public int saveStep = 10;
public int beginSaveIters = 80;
}
}
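The code above judges convergence by eye, by comparing saved topic-word snapshots. A more quantitative check is to track the training-set perplexity between saves. The method below is a minimal sketch of that idea, not part of the original model; it assumes updateEstimatedParameters has just been called so theta and phi reflect the current counts:

// Hypothetical addition to LdaModel: training-set perplexity exp(-sum(log p(w)) / N).
// Lower is better; the curve flattening out between snapshots suggests convergence.
public double perplexity() {
double logLik = 0.0;
int wordCount = 0;
for (int m = 0; m < M; m++) {
for (int n = 0; n < doc[m].length; n++) {
double pw = 0.0;
for (int k = 0; k < K; k++) {
pw += theta[m][k] * phi[k][doc[m][n]]; // p(w) = sum_k p(k|d) * p(w|k)
}
logLik += Math.log(pw);
wordCount++;
}
}
return Math.exp(-logLik / wordCount);
}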
package org.jazywoo.lda;
import java.io.IOException;
public class LDATest {
/**
* @param args
* @throws IOException
*/
public static void main(String[] args) throws IOException {
LdaModel.ModelParameter parameter=new LdaModel.ModelParameter();
LdaModel ldaModel=new LdaModel(parameter);
String path="D:\\zz";
Corpus docSet=new Corpus();
docSet.loadData(path);
ldaModel.initModel(docSet);
ldaModel.inferenceModel();
ldaModel.saveIteratedModel(parameter.iteration);
}
}
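With the default parameters, snapshots are written at iterations 80 and 90 during inference and once more after the final iteration (lda_100), each as a topic_words.txt file under D:\result\ listing the 20 most probable words per topic. Comparing consecutive snapshots is how the stability claim at the top can be checked: once the sampler has converged, the word lists change very little between saves.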