package LDA;
/**
* LDA GibbsSampling
*
* @author: Liang Yao
*/
/**
 * Latent Dirichlet Allocation (LDA) trained with collapsed Gibbs sampling.
 *
 * <p>Usage: construct with a corpus and vocabulary size, call
 * {@link #markovChain(int, double, double, int)} to run the sampler, then read the
 * estimated document-topic distribution via {@link #estimateTheta()} and the
 * topic-word distribution via {@link #estimatePhi()}.
 *
 * @author Liang Yao
 */
public class LDA {

    /** documents[d][n] = vocabulary index of the n-th token of document d. */
    int[][] documents;
    /** Vocabulary size; word indices must lie in [0, V). */
    int V;
    /** Number of topics (set by markovChain). */
    int K;
    /** Symmetric Dirichlet prior on per-document topic proportions. */
    double alpha;
    /** Symmetric Dirichlet prior on per-topic word proportions. */
    double beta;
    /** z[d][n] = topic currently assigned to token n of document d. */
    int[][] z;
    /** nw[w][k] = number of times word w is assigned to topic k. */
    int[][] nw;
    /** nd[d][k] = number of tokens in document d assigned to topic k. */
    int[][] nd;
    /** nwsum[k] = total number of tokens assigned to topic k. */
    int[] nwsum;
    /** ndsum[d] = number of tokens in document d (constant after init). */
    int[] ndsum;
    /** Number of full Gibbs sweeps to perform. */
    int iterations;

    /** Single RNG instance; seed here in one place for reproducible runs. */
    private final java.util.Random rng = new java.util.Random();

    /**
     * @param documents corpus; documents[d][n] is a word index in [0, V)
     * @param V vocabulary size
     */
    public LDA(int[][] documents, int V) {
        this.documents = documents;
        this.V = V;
    }

    /**
     * Allocates the count arrays and assigns every token a uniformly random
     * topic, updating all counters consistently. Requires K to be set first;
     * {@link #markovChain(int, double, double, int)} guarantees this ordering.
     */
    public void initialState() {
        int D = documents.length;
        nw = new int[V][K];
        nd = new int[D][K];
        nwsum = new int[K];
        ndsum = new int[D];
        z = new int[D][];
        for (int d = 0; d < D; d++) {
            int Nd = documents[d].length;
            z[d] = new int[Nd];
            for (int n = 0; n < Nd; n++) {
                int topic = rng.nextInt(K);
                z[d][n] = topic;
                nw[documents[d][n]][topic]++;
                nd[d][topic]++;
                nwsum[topic]++;
            }
            ndsum[d] = Nd;
        }
    }

    /**
     * Runs the collapsed Gibbs sampler for the given number of sweeps.
     *
     * @param K          number of topics (must be positive)
     * @param alpha      document-topic Dirichlet hyperparameter
     * @param beta       topic-word Dirichlet hyperparameter
     * @param iterations number of full sweeps over the corpus
     * @throws IllegalArgumentException if K is not positive or iterations is negative
     */
    public void markovChain(int K, double alpha, double beta, int iterations) {
        if (K <= 0) {
            throw new IllegalArgumentException("K must be positive: " + K);
        }
        if (iterations < 0) {
            throw new IllegalArgumentException("iterations must be non-negative: " + iterations);
        }
        this.K = K;
        this.alpha = alpha;
        this.beta = beta;
        this.iterations = iterations;
        initialState();
        for (int i = 0; i < this.iterations; i++) {
            System.out.println(i); // progress indicator
            gibbs();
        }
    }

    /** One full Gibbs sweep: resamples the topic of every token in the corpus. */
    public void gibbs() {
        for (int d = 0; d < z.length; d++) {
            for (int n = 0; n < z[d].length; n++) {
                z[d][n] = sampleFullConditional(d, n);
            }
        }
    }

    /**
     * Resamples the topic of token n in document d from its full conditional
     * distribution, using the standard "remove token, sample, re-add" scheme.
     *
     * @return the newly sampled topic index in [0, K)
     */
    int sampleFullConditional(int d, int n) {
        int topic = z[d][n];
        int w = documents[d][n];
        // Remove the current assignment from all counts.
        nw[w][topic]--;
        nd[d][topic]--;
        nwsum[topic]--;
        ndsum[d]--;
        // Unnormalized full-conditional weights, then cumulative sums in place.
        double[] p = new double[K];
        for (int k = 0; k < K; k++) {
            p[k] = (nd[d][k] + alpha) / (ndsum[d] + K * alpha)
                    * (nw[w][k] + beta) / (nwsum[k] + V * beta);
        }
        for (int k = 1; k < K; k++) {
            p[k] += p[k - 1];
        }
        // Inverse-CDF sampling: u is uniform on [0, total weight).
        double u = rng.nextDouble() * p[K - 1];
        for (int t = 0; t < K; t++) {
            if (u < p[t]) {
                topic = t;
                break;
            }
        }
        // Re-add the (possibly new) assignment to the counts.
        nw[w][topic]++;
        nd[d][topic]++;
        nwsum[topic]++;
        ndsum[d]++;
        return topic;
    }

    /**
     * @return theta[d][k], the smoothed estimate of P(topic k | document d)
     */
    public double[][] estimateTheta() {
        double[][] theta = new double[documents.length][K];
        for (int d = 0; d < documents.length; d++) {
            for (int k = 0; k < K; k++) {
                theta[d][k] = (nd[d][k] + alpha) / (ndsum[d] + K * alpha);
            }
        }
        return theta;
    }

    /**
     * @return phi[k][w], the smoothed estimate of P(word w | topic k)
     */
    public double[][] estimatePhi() {
        double[][] phi = new double[K][V];
        for (int k = 0; k < K; k++) {
            for (int w = 0; w < V; w++) {
                phi[k][w] = (nw[w][k] + beta) / (nwsum[k] + V * beta);
            }
        }
        return phi;
    }

    /** Demo: trains on a tiny hard-coded corpus and prints z, theta, and phi. */
    public static void main(String[] args) {
        // Words in documents (each value is a vocabulary index).
        int[][] documents = {
                { 1, 4, 3, 2, 3, 1, 4, 3, 2, 3, 1, 4, 3, 2, 3, 6 },
                { 2, 2, 4, 2, 4, 2, 2, 2, 2, 4, 2, 2 },
                { 1, 6, 5, 6, 0, 1, 6, 5, 6, 0, 1, 6, 5, 6, 0, 0 },
                { 5, 6, 6, 2, 3, 3, 6, 5, 6, 2, 2, 6, 5, 6, 6, 6, 0 },
                { 2, 2, 4, 4, 4, 4, 1, 5, 5, 5, 5, 5, 5, 1, 1, 1, 1, 0 },
                { 5, 4, 2, 3, 4, 5, 6, 6, 5, 4, 3, 2 },
        };
        // Vocabulary size.
        int V = 7;
        // Number of topics.
        int K = 3;
        // Good values: alpha = 2, beta = .5 (kept at 0.1 to match original demo).
        double alpha = 0.1;
        double beta = 0.1;
        int iterations = 1000;
        LDA lda = new LDA(documents, V);
        lda.markovChain(K, alpha, beta, iterations);
        double[][] theta = lda.estimateTheta();
        double[][] phi = lda.estimatePhi();
        // Final topic assignment of every token.
        for (int d = 0; d < lda.z.length; d++) {
            for (int m = 0; m < lda.z[d].length; m++) {
                System.out.print(lda.z[d][m] + " ");
            }
            System.out.println();
        }
        System.out.println();
        // Document-topic distributions.
        for (int m = 0; m < theta.length; m++) {
            System.out.print(m + "\t");
            for (int k = 0; k < theta[m].length; k++) {
                System.out.print(theta[m][k] + " ");
            }
            System.out.println();
        }
        System.out.println();
        // Topic-word distributions.
        for (int k = 0; k < phi.length; k++) {
            System.out.print(k + "\t");
            for (int w = 0; w < phi[k].length; w++) {
                System.out.print(phi[k][w] + " ");
            }
            System.out.println();
        }
    }
}
// Simplified LDA source code.
// (Page footer: latest recommended article published 2024-05-17 08:57:45.)