关闭

mallet 简析 1

标签: malletlda
582人阅读 评论(1) 收藏 举报
分类:

    

         最近一直在学习LDA 看来blei的C代码和matlab代码,matlab 的速度真是慢的不行,找到了MALLET ,想看详细分析,可惜网上大都是mallet的使用,自己就按照自己的理解把其过程简要的写出来。mallet 网址: http://mallet.cs.umass.edu/topics.php

          数据下载网址: http://www.nsf.gov/awardsearch/download.jsp

          其中topic-modeling-tool (http://code.google.com/p/topic-modeling-tool/ )是实现LDA过程的一个界面程序,配置好环境之后,运行界面如下:

    1、 在TopicModelingTool.java  的 m.invoke(null, passeArgs) 处将要处理的文档整合为mallet文件,文档的处理在自己没有选择stopword文件的时候去掉默认的stopword。 并对单词进行编号。

    2、在vectors2Topics.java 的 403 行的  training = InstanceList.load (new File(inputFile.value)); 读入数据,每篇文档以【word   wordid 】的方式。

           数据读入之后开始新建模型,初始化模型在该类别的422行: topicModel = new ParallelTopicModel (numTopics.value, alpha.value, beta.value);   初始化在parallelTopicModel.java 的118行,传入的参数为:numtopic 、sum-alpha、 beta.。 [初始化alpha=50,将此值赋给alphasum, 然后alpha = alpahsum / numtopics]

          初始化模型的过程:              
        this.data = new ArrayList<TopicAssignment>();
        this.topicAlphabet = topicAlphabet;
        this.numTopics = topicAlphabet.size();

        if (Integer.bitCount(numTopics) == 1) {
            // exact power of 2
            topicMask = numTopics - 1;
            topicBits = Integer.bitCount(topicMask);
        }
        else {
            // otherwise add an extra bit
            topicMask = Integer.highestOneBit(numTopics) * 2 - 1;
            topicBits = Integer.bitCount(topicMask);
        }


        this.alphaSum = alphaSum;
        this.alpha = new double[numTopics];
        Arrays.fill(alpha, alphaSum / numTopics);
        this.beta = beta;
        
        tokensPerTopic = new int[numTopics];
        
        formatter = NumberFormat.getInstance();
        formatter.setMaximumFractionDigits(5);

        logger.info("Mallet LDA: " + numTopics + " topics, " + topicBits + " topic bits, " +
                    Integer.toBinaryString(topicMask) + " topic mask");
   

      topicMask、topicBits 以及模型的alpha、alphasum、 beta 、tokensPerTopic (每个topic中单词个数)

3、将训练集加入到模型中开始训练:该java代码的427行。跳转到ParallelTopicModel.java 的 217 行。随机的为每个文档中的单词初始化一个topic ,并更新topic-counts,

          public void addInstances (InstanceList training) {

        alphabet = training.getDataAlphabet();        //   模型的单词集合
        numTypes = alphabet.size();               //  V
        
        betaSum = beta * numTypes;                 // V*beta

        Randoms random = null;
        if (randomSeed == -1) {
            random = new Randoms();
        }
        else {
            random = new Randoms(randomSeed);
        }

        for (Instance instance : training) {
            FeatureSequence tokens = (FeatureSequence) instance.getData();//文档中的单词以及标号
            LabelSequence topicSequence =
                new LabelSequence(topicAlphabet, new int[ tokens.size() ]); //初始的时候全都归到topic0中
            
            int[] topics = topicSequence.getFeatures();
            for (int position = 0; position < topics.length; position++) {

                int topic = random.nextInt(numTopics);   //此处也是随机的赋予了一个标号
                topics[position] = topic;
                
            }

            TopicAssignment t = new TopicAssignment(instance, topicSequence);
            data.add(t);
        }
        
        buildInitialTypeTopicCounts();
        initializeHistograms();
    }


4、开始采样: Vectors2Topics  的line 453 :  topicModel.estimate();

      estimate()方法在  ParallelTopicModel  只有一个线程举例:

      line 746:

        runnables[0] = new WorkerRunnable(numTopics, alpha, alphaSum, beta,   random, data,  typeTopicCounts, tokensPerTopic,  offset, docsPerThread); 

     将参数传入模型中。
    Line  862 

      runnables[0].run();  

     WoekerRunnable.java  line 275: 针对每个文档采样,其中tokenSequence 是 文档中的单词序列,topicSequence是文档中单词所属类别的标号。


要去上课了,晚上回来再继续写采样部分。


1
0

查看评论
* 以上用户言论只代表其个人观点,不代表CSDN网站的观点或立场
    个人资料
    • 访问:26074次
    • 积分:1547
    • 等级:
    • 排名:千里之外
    • 原创:130篇
    • 转载:0篇
    • 译文:0篇
    • 评论:17条
    最新评论