简化版的LDA源码

package LDA;

import java.util.Objects;
import java.util.Random;

/**
 * Latent Dirichlet Allocation (LDA) fitted with collapsed Gibbs sampling.
 *
 * <p>The corpus is a ragged array of word ids: {@code documents[d][n]} is the id of the
 * n-th token of document d and must lie in {@code [0, V)}. Call
 * {@link #markovChain(int, double, double, int)} to run the sampler, then read the
 * smoothed estimates with {@link #estimateTheta()} (document-topic) and
 * {@link #estimatePhi()} (topic-word).
 *
 * @author Liang Yao
 */
public class LDA {

	/** documents[d][n]: word id of the n-th token of document d. */
	int[][] documents;

	/** Vocabulary size; every word id must lie in [0, V). */
	int V;

	/** Number of topics. */
	int K;

	/** Symmetric Dirichlet prior on the per-document topic mixtures. */
	double alpha;

	/** Symmetric Dirichlet prior on the per-topic word distributions. */
	double beta;

	/** z[d][n]: topic currently assigned to token n of document d. */
	int[][] z;

	/** nw[w][k]: number of tokens of word w assigned to topic k. */
	int[][] nw;

	/** nd[d][k]: number of tokens in document d assigned to topic k. */
	int[][] nd;

	/** nwsum[k]: total number of tokens assigned to topic k. */
	int[] nwsum;

	/** ndsum[d]: number of tokens in document d. */
	int[] ndsum;

	/** Number of Gibbs sweeps over the whole corpus. */
	int iterations;

	/** Source of randomness for both initialization and sampling. */
	private final Random random;

	/**
	 * Creates a sampler with an unseeded random source.
	 *
	 * @param documents corpus of word ids, {@code documents[d][n]} in {@code [0, V)}
	 * @param V         vocabulary size
	 * @throws NullPointerException     if {@code documents} or any document is null
	 * @throws IllegalArgumentException if {@code V < 0} or a word id is out of range
	 */
	public LDA(int[][] documents, int V) {
		this(documents, V, new Random());
	}

	/**
	 * Creates a sampler whose results are reproducible for a given seed.
	 *
	 * @param documents corpus of word ids, {@code documents[d][n]} in {@code [0, V)}
	 * @param V         vocabulary size
	 * @param seed      seed for the internal {@link Random}
	 * @throws NullPointerException     if {@code documents} or any document is null
	 * @throws IllegalArgumentException if {@code V < 0} or a word id is out of range
	 */
	public LDA(int[][] documents, int V, long seed) {
		this(documents, V, new Random(seed));
	}

	/**
	 * Creates a sampler that draws from the supplied random source.
	 *
	 * @param documents corpus of word ids, {@code documents[d][n]} in {@code [0, V)}
	 * @param V         vocabulary size
	 * @param random    random source used for initialization and sampling
	 * @throws NullPointerException     if {@code documents}, any document, or {@code random} is null
	 * @throws IllegalArgumentException if {@code V < 0} or a word id is out of range
	 */
	public LDA(int[][] documents, int V, Random random) {
		validateCorpus(documents, V);
		this.documents = documents;
		this.V = V;
		this.random = Objects.requireNonNull(random, "random");
	}

	/** Checks that every document is non-null and every word id lies in [0, V). */
	private static void validateCorpus(int[][] documents, int V) {
		Objects.requireNonNull(documents, "documents");
		if (V < 0) {
			throw new IllegalArgumentException("V must be >= 0, got " + V);
		}
		for (int d = 0; d < documents.length; d++) {
			int[] doc = Objects.requireNonNull(documents[d], "documents[" + d + "]");
			for (int n = 0; n < doc.length; n++) {
				int w = doc[n];
				if (w < 0 || w >= V) {
					throw new IllegalArgumentException("word id " + w + " at documents[" + d + "][" + n
							+ "] is outside [0, " + V + ")");
				}
			}
		}
	}

	/**
	 * Allocates fresh count tables and assigns every token a uniformly random topic.
	 * Any previous sampler state is discarded.
	 *
	 * @throws IllegalStateException if {@code K} has not been set to a value >= 1
	 */
	public void initialState() {
		if (K < 1) {
			throw new IllegalStateException("K must be >= 1 before initialState(), got " + K);
		}
		int D = documents.length;
		nw = new int[V][K];
		nd = new int[D][K];
		nwsum = new int[K];
		ndsum = new int[D];

		z = new int[D][];
		for (int d = 0; d < D; d++) {
			int Nd = documents[d].length;
			z[d] = new int[Nd];

			for (int n = 0; n < Nd; n++) {
				int topic = random.nextInt(K);
				z[d][n] = topic;
				nw[documents[d][n]][topic]++;
				nd[d][topic]++;
				nwsum[topic]++;
			}
			ndsum[d] = Nd;
		}
	}

	/**
	 * Sets the hyperparameters, randomly initializes the topic assignments, and runs
	 * {@code iterations} Gibbs sweeps over the corpus.
	 *
	 * @param K          number of topics, must be >= 1
	 * @param alpha      document-topic prior, must be > 0
	 * @param beta       topic-word prior, must be > 0
	 * @param iterations number of sweeps, must be >= 0
	 * @throws IllegalArgumentException if any argument is out of range
	 */
	public void markovChain(int K, double alpha, double beta, int iterations) {
		if (K < 1) {
			throw new IllegalArgumentException("K must be >= 1, got " + K);
		}
		// Written as !(x > 0) so that NaN is rejected too.
		if (!(alpha > 0)) {
			throw new IllegalArgumentException("alpha must be > 0, got " + alpha);
		}
		if (!(beta > 0)) {
			throw new IllegalArgumentException("beta must be > 0, got " + beta);
		}
		if (iterations < 0) {
			throw new IllegalArgumentException("iterations must be >= 0, got " + iterations);
		}

		this.K = K;
		this.alpha = alpha;
		this.beta = beta;
		this.iterations = iterations;

		initialState();

		for (int i = 0; i < this.iterations; i++) {
			gibbs();
		}
	}

	/** Performs one Gibbs sweep, resampling the topic of every token in corpus order. */
	public void gibbs() {
		for (int d = 0; d < z.length; d++) {
			for (int n = 0; n < z[d].length; n++) {
				z[d][n] = sampleFullConditional(d, n);
			}
		}
	}

	/**
	 * Draws a new topic for token n of document d from its full conditional.
	 * The token is removed from the counts, a topic is sampled with probability
	 * proportional to (nd[d][k] + alpha) * (nw[w][k] + beta) / (nwsum[k] + V * beta),
	 * and the counts are restored with the new topic. The caller stores the result in
	 * {@code z[d][n]}.
	 *
	 * @return the sampled topic in [0, K)
	 */
	int sampleFullConditional(int d, int n) {
		final int w = documents[d][n];
		int topic = z[d][n];

		// Exclude the current token so the conditional is over all *other* assignments.
		nw[w][topic]--;
		nd[d][topic]--;
		nwsum[topic]--;
		ndsum[d]--;

		// p[k] holds the cumulative (unnormalized) probability mass of topics 0..k.
		double[] p = new double[K];
		double docDenominator = ndsum[d] + K * alpha; // identical for every k
		double cumulative = 0.0;
		for (int k = 0; k < K; k++) {
			cumulative += (nd[d][k] + alpha) / docDenominator
					* (nw[w][k] + beta) / (nwsum[k] + V * beta);
			p[k] = cumulative;
		}

		double u = random.nextDouble() * p[K - 1];
		// Fall back to the last topic so that floating-point rounding which lets u reach
		// p[K - 1] cannot leave the token on its old, un-sampled topic.
		topic = K - 1;
		for (int t = 0; t < K - 1; t++) {
			if (u < p[t]) {
				topic = t;
				break;
			}
		}

		nw[w][topic]++;
		nd[d][topic]++;
		nwsum[topic]++;
		ndsum[d]++;
		return topic;
	}

	/**
	 * Returns the smoothed document-topic distributions,
	 * theta[d][k] = (nd[d][k] + alpha) / (ndsum[d] + K * alpha).
	 *
	 * @throws IllegalStateException if the sampler has not been initialized
	 */
	public double[][] estimateTheta() {
		requireSampled();
		double[][] theta = new double[documents.length][K];
		for (int d = 0; d < documents.length; d++) {
			for (int k = 0; k < K; k++) {
				theta[d][k] = (nd[d][k] + alpha) / (ndsum[d] + K * alpha);
			}
		}
		return theta;
	}

	/**
	 * Returns the smoothed topic-word distributions,
	 * phi[k][w] = (nw[w][k] + beta) / (nwsum[k] + V * beta).
	 *
	 * @throws IllegalStateException if the sampler has not been initialized
	 */
	public double[][] estimatePhi() {
		requireSampled();
		double[][] phi = new double[K][V];
		for (int k = 0; k < K; k++) {
			for (int w = 0; w < V; w++) {
				phi[k][w] = (nw[w][k] + beta) / (nwsum[k] + V * beta);
			}
		}
		return phi;
	}

	/** Fails fast when estimates are requested before the count tables exist. */
	private void requireSampled() {
		if (nd == null || nw == null) {
			throw new IllegalStateException("call markovChain(...) or initialState() first");
		}
	}

	public static void main(String[] args) {

		// words in documents
		int[][] documents = {
				{ 1, 4, 3, 2, 3, 1, 4, 3, 2, 3, 1, 4, 3, 2, 3, 6 },
				{ 2, 2, 4, 2, 4, 2, 2, 2, 2, 4, 2, 2 },
				{ 1, 6, 5, 6, 0, 1, 6, 5, 6, 0, 1, 6, 5, 6, 0, 0 },
				{ 5, 6, 6, 2, 3, 3, 6, 5, 6, 2, 2, 6, 5, 6, 6, 6, 0 },
				{ 2, 2, 4, 4, 4, 4, 1, 5, 5, 5, 5, 5, 5, 1, 1, 1, 1, 0 },
				{ 5, 4, 2, 3, 4, 5, 6, 6, 5, 4, 3, 2 },

		};

		// vocabulary
		int V = 7;
		// # topics
		int K = 3;
		// good values alpha = 2, beta = .5
		double alpha = 0.1;
		double beta = 0.1;

		int iterations = 1000;

		LDA lda = new LDA(documents, V);
		lda.markovChain(K, alpha, beta, iterations);

		double[][] theta = lda.estimateTheta();
		double[][] phi = lda.estimatePhi();

		// Final topic assignment of every token, one document per line.
		for (int d = 0; d < lda.z.length; d++) {
			for (int m = 0; m < lda.z[d].length; m++) {
				System.out.print(lda.z[d][m] + " ");
			}
			System.out.println();
		}
		System.out.println();

		// Document-topic distributions.
		for (int m = 0; m < theta.length; m++) {
			System.out.print(m + "\t");
			for (int k = 0; k < theta[m].length; k++) {
				System.out.print(theta[m][k] + " ");
			}
			System.out.println();
		}

		System.out.println();

		// Topic-word distributions.
		for (int k = 0; k < phi.length; k++) {
			System.out.print(k + "\t");
			for (int w = 0; w < phi[k].length; w++) {
				System.out.print(phi[k][w] + " ");
			}
			System.out.println();
		}
	}
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值