Java实现LDA在线变分

该博客介绍了一种使用Java实现的LDA(Latent Dirichlet Allocation)在线变分算法。作者提供了代码实现,包括初始化参数、E步(变分主题推断)和M步(更新话题分布)等关键步骤。博客通过具体的数据集进行演示,并展示了如何迭代更新模型参数。
摘要由CSDN通过智能技术生成

原理可以看徐亦达的视频和白板推导的视频,论文可以看Online Learning for Latent Dirichlet Allocation, 我没有跟新alpha 如果跟新 可以看http://jonathan-huang.org/research/dirichlet/dirichlet.pdf或者源码

代码

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;

import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.spark.mllib.linalg.Matrices;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.util.random.PoissonSampler;

import breeze.stats.distributions.Gamma;
import breeze.stats.distributions.RandBasis;
import breeze.stats.distributions.ThreadLocalRandomGenerator;
import scala.Tuple2;
import scala.collection.JavaConversions;
import scala.reflect.ClassManifestFactory;
import scala.reflect.ClassTag;
import scala.runtime.AbstractFunction0;

/**
 * 
 */

public class JavaOnlineVLDA3
{

    private static ClassTag<Double> tagDouble = ClassManifestFactory.classType( Double.class );
    private boolean optimizeDocConcentration = false;
    private int iteration = 0;
    private int k = 2;
    private long corpusSize;
    private int vocabSize;
    private Random randomGenerator;

    private double gammaShape = 100;
    private double tau0 = 1024;
    private double kappa = 0.51;
    private Vector alpha = Vectors.dense( new double[]{
            0
    } );
    private double eta;
    private long seed = 10;

    /**
     * 论文中的lambda
     */
    private Matrix lambda;
    private double miniBatchFraction = 0.5;
    private static int maxIterations = 4;

    public static void main( String[] args )
    {
        List<Tuple2<Long, Vector>> data = Arrays.asList( new Tuple2<Long, Vector>( 0L, Vectors.dense( new double[]{
                1, 1, 0, 3
        } ) ), new Tuple2<Long, Vector>( 1L, Vectors.dense( new double[]{
                1, 0, 0, 3
        } ) ), new Tuple2<Long, Vector>( 3L, Vectors.dense( new double[]{
                1, 0, 2, 3
        } ) ), new Tuple2<Long, Vector>( 2L, Vectors.dense( new double[]{
                0, 2, 0, 3
        } ) ) );

        testGaoLDA( data );
    }

    private static void testGaoLDA( List<Tuple2<Long, Vector>> data )
    {
        JavaOnlineVLDA3 state = new JavaOnlineVLDA3( );
        state.initialize( data );
        int iter = 0;

        while ( iter < maxIterations )
        {
            state = state.next( data );
            iter++;
        }

        System.out.println( state.lambda.transpose( ) );
        System.out.println( state.alpha );

        System.out.println( "done" );
    }

    private void initialize( List<Tuple2<Long, Vector>> doc )
    {
        this.corpusSize = doc.size( );
        this.vocabSize = doc.get( 0 )._2.size( );
        // init alpha
        double[] ds = new double[k];
        Arrays.fill( ds, 0, k, 1.0 / k );
        this.alpha = Vectors.dense( ds );

        // init beta
        this.eta = 1.0 / k;

        this.randomGenerator = new Random( seed );

        //init lambda
        this.lambda = getGammaMatrix( k, vocabSize );

    }

    public JavaOnlineVLDA3 next( List<Tuple2<Long, Vector>> data )
    {
        List<Tuple2<Long, Vector>> batch = sample( data, miniBatchFraction, seed );
        if ( batch.isEmpty( ) )
        {
            return this;
        }
        return submitMiniBatch( batch );
    }

    private JavaOnlineVLDA3 submitMiniBatch( List<Tuple2<Long, Vector>> data )
    {
        iteration = iteration + 1;
        
        //转会真正的值
        Matrix expElogbeta = expMatrixWithT( dirichletExpectation( lambda ) );
        List<Tuple2<Matrix, List<Vector>>> stats = new ArrayList<Tuple2<Matrix, List&l

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值