原理可以看徐亦达的视频和白板推导的视频，论文可以看《Online Learning for Latent Dirichlet Allocation》。我没有更新 alpha；如果要更新，可以参考 http://jonathan-huang.org/research/dirichlet/dirichlet.pdf 或者源码。
代码
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.stream.Collectors;
import org.apache.commons.collections.IteratorUtils;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.apache.spark.mllib.linalg.Matrices;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.SparseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.util.random.PoissonSampler;
import breeze.stats.distributions.Gamma;
import breeze.stats.distributions.RandBasis;
import breeze.stats.distributions.ThreadLocalRandomGenerator;
import scala.Tuple2;
import scala.collection.JavaConversions;
import scala.reflect.ClassManifestFactory;
import scala.reflect.ClassTag;
import scala.runtime.AbstractFunction0;
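/**
 * In-memory demo of online variational inference for LDA, following the paper
 * "Online Learning for Latent Dirichlet Allocation". Documents and model
 * parameters are held in Spark MLlib {@code Vector} / {@code Matrix} values.
 */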
public class JavaOnlineVLDA3
{
// ClassTag for java.lang.Double obtained from Scala's ClassManifestFactory.
// Not read in the visible part of the file; presumably required by Scala/breeze calls further down (TODO confirm).
private static ClassTag<Double> tagDouble = ClassManifestFactory.classType( Double.class );
// Presumably toggles re-estimation of alpha; it is never set to true in the visible code (TODO confirm where it is read).
private boolean optimizeDocConcentration = false;
// Counter incremented by one at the start of every submitMiniBatch call.
private int iteration = 0;
// Length of alpha and first argument to getGammaMatrix in initialize; presumably the number of topics.
private int k = 2;
// Number of documents handed to initialize.
private long corpusSize;
// Size of the first document's vector, read in initialize; presumably the vocabulary size shared by all documents.
private int vocabSize;
// Created in initialize from seed.
private Random randomGenerator;
// gammaShape, tau0 and kappa are not read in the visible part of the file.
// NOTE(review): presumably the gamma shape used for random initialisation and the
// learning-rate offset / decay of the online LDA paper — confirm against their use sites.
private double gammaShape = 100;
private double tau0 = 1024;
private double kappa = 0.51;
// Placeholder single-element vector; initialize replaces it with k entries of 1/k.
private Vector alpha = Vectors.dense( new double[]{
0
} );
// Set to 1/k in initialize (labelled "init beta" there).
private double eta;
// Seeds randomGenerator in initialize and is also passed to sample in next.
private long seed = 10;
/**
* The lambda of the paper. Assigned in initialize from getGammaMatrix( k, vocabSize ).
*/
private Matrix lambda;
// Fraction argument passed to sample in next.
private double miniBatchFraction = 0.5;
// Number of times testGaoLDA calls next.
private static int maxIterations = 4;
/**
 * Entry point: builds a hard-coded corpus of four (id, vector) documents and
 * hands it to {@link #testGaoLDA(List)}.
 * Each vector has four entries — presumably per-term counts (TODO confirm).
 *
 * @param args ignored
 */
public static void main( String[] args )
{
    Tuple2<Long, Vector> docZero = new Tuple2<>( 0L, Vectors.dense( new double[]{ 1, 1, 0, 3 } ) );
    Tuple2<Long, Vector> docOne = new Tuple2<>( 1L, Vectors.dense( new double[]{ 1, 0, 0, 3 } ) );
    Tuple2<Long, Vector> docThree = new Tuple2<>( 3L, Vectors.dense( new double[]{ 1, 0, 2, 3 } ) );
    Tuple2<Long, Vector> docTwo = new Tuple2<>( 2L, Vectors.dense( new double[]{ 0, 2, 0, 3 } ) );

    // Keep the list order of the ids as 0, 1, 3, 2.
    List<Tuple2<Long, Vector>> corpus = Arrays.asList( docZero, docOne, docThree, docTwo );
    testGaoLDA( corpus );
}
/**
 * Initializes a fresh model on {@code data}, calls {@link #next(List)}
 * {@code maxIterations} times, then prints the transposed lambda matrix,
 * alpha, and the literal "done".
 *
 * @param data documents as (id, vector) pairs; used for initialization and for every step
 */
private static void testGaoLDA( List<Tuple2<Long, Vector>> data )
{
    JavaOnlineVLDA3 model = new JavaOnlineVLDA3( );
    model.initialize( data );

    for ( int round = 0; round < maxIterations; round++ )
    {
        // next hands back the instance to continue with.
        model = model.next( data );
    }

    System.out.println( model.lambda.transpose( ) );
    System.out.println( model.alpha );
    System.out.println( "done" );
}
/**
 * Sets up the model state from the given documents: corpus size, vector size,
 * a uniform alpha of {@code k} entries, eta, the random generator and lambda.
 *
 * @param doc documents as (id, vector) pairs; the first entry's vector size is
 *            taken as {@code vocabSize}, so the list must not be empty
 */
private void initialize( List<Tuple2<Long, Vector>> doc )
{
    this.corpusSize = doc.size( );
    Vector firstVector = doc.get( 0 )._2;
    this.vocabSize = firstVector.size( );

    // alpha and eta both start at the same value, 1/k.
    final double uniform = 1.0 / k;
    double[] alphaValues = new double[k];
    Arrays.fill( alphaValues, uniform );
    this.alpha = Vectors.dense( alphaValues );
    this.eta = uniform;

    // The generator is created before lambda is drawn; getGammaMatrix is
    // defined elsewhere in the file and presumably relies on it.
    this.randomGenerator = new Random( seed );
    this.lambda = getGammaMatrix( k, vocabSize );
}
/**
 * Performs one step: draws a mini-batch from {@code data} and submits it.
 *
 * @param data documents as (id, vector) pairs to sample from
 * @return this instance unchanged when the sampled mini-batch is empty,
 *         otherwise whatever {@code submitMiniBatch} returns
 */
public JavaOnlineVLDA3 next( List<Tuple2<Long, Vector>> data )
{
    // NOTE(review): the same seed is passed on every call; confirm that sample
    // (defined elsewhere) does not therefore return the same mini-batch each time.
    List<Tuple2<Long, Vector>> miniBatch = sample( data, miniBatchFraction, seed );
    return miniBatch.isEmpty( ) ? this : submitMiniBatch( miniBatch );
}
private JavaOnlineVLDA3 submitMiniBatch( List<Tuple2<Long, Vector>> data )
{
iteration = iteration + 1;
//转会真正的值
Matrix expElogbeta = expMatrixWithT( dirichletExpectation( lambda ) );
List<Tuple2<Matrix, List<Vector>>> stats = new ArrayList<Tuple2<Matrix, List&l