Gibbs LDA java实现

1.偏文、偏理的故事


    某学校高一年级有6个班级,每个班级各有一定数量的学生,3班有几个同学数学成绩很好,拿过省奥赛奖。现在教育局要来该校听数学课,学校应该安排听课老师听哪个班的课?显然是3班,因为3班有几个数学特别厉害的同学,所以3班数学强一点,至少看起来数学强一点.这里,我们把"偏理"称为3班的特点。同样,2班和4班有很多同学的语文成绩很好,他们的作文都曾被文学报刊发表过,我们可以说”偏文“是2班和4班的特点。又假如5班和6班的同学在校篮球赛上进了决赛,我们可以说5班和6班”偏体育“。如果教育局来该校听某种课程,我们就可以安排他们去有该课程"特点"的班级里听。
    在这里,原来的班级结构是只有两层,即学生层,和班级层,每个学生都有指定的班级。我们为了区分每个班级的特点,在学生和班级之间又加了一层,特点层,即”偏文“,”偏理“,”偏体育“。这个特点层就是对LDA最直观的理解。接着上面偏文偏理的故事,3班除了几个同学数学好,另外还有一部分同学思想品德很好,因多次扶老奶奶过马路而上新闻,我们同样也可以说3班同学思想品德很好。这样3班的特点就不只1个了,这里我们提出分布的概念,即每个班级可能有多个特点,只是有的特点对应的学生多,有的特点对应的学生少,我们选对应学生多的特点作为这个班级的主要特点。每个特点同样也对应多个同学也是一种分布,比如,”偏理“包含拿过奥赛奖的同学,也包含期末考试数学考满分的同学。到这里,班级包含多个特点,每个特点又包含多个学生,LDA的主要结构就是这样。
    对于一片文档,我们怎么区分这篇文档是属于那个类别?参照偏文偏理的例子, 我们可以把文档想象成班级,word想象成学生。例如某篇文档的单词中,银行,汇率,股票,下跌等次大量重复出现,那么该篇文档很有可能就是写经济的,我们可以把这篇文档归为经济类。如果某篇文档里面含有,詹姆斯,科比,扣篮,犯规等词,那么这篇文章很有可能是体育类。当然这种分类不一定是单一的,有可能一个文章有多个主题。
三层结构如下:

doct: |           doc1                     doc2                     doc3 ...

topic:|      t1     t2      t3...      t1   t2    t3            t1  t2  t3...

word: |w1 w5 w8... w6 w2 w3....      w3,w5... 

最终要求的,doc下面的topics分布和topic下的words分布.LDA原理见原论文,不赘述.
输出文件中有各种分布:
topic ~ words
doc ~ topics
topic ~ docs
详见JGibbLDA的输出文件

2.Gibbs LDA代码结构


    第一次读代码时把lda分成了两部分,即训练部分和推测部分,训练部分训练出来模型,即topic下面的words分布等,推测部分是用训练出的模型推测新的文章。后来发现推测部分也是一种训练,只是参考了已训练好的结果再训练.如果推测的文件数据量大于参考的数据量,那么这个推测集推测出来的结果,可以当成新的模型,更为准确。训练过程和推测过程的结果类型是完全相同的,包含各个完整的分布,详见JGibbLDA的输出文件

代码除了读入,保存之类的,核心代码不到200行.LDA 结构与代码如下:

  • 1.预处理:
    去停词表,去noise词,低频词等等.
  • 2.Estimate:推测过程
package jgibblda;

import java.io.File;
import java.util.Vector;

public class Estimator {
    
    // output model
    protected Model trnModel;
    LDACmdOption option;
    
    public boolean init(LDACmdOption option){
        this.option = option;
        trnModel = new Model();
        
        if (option.est){
            if (!trnModel.initNewModel(option))
                return false;
            trnModel.data.localDict.writeWordMap(option.dir + File.separator + option.wordMapFileName);
        }
        else if (option.estc){
            if (!trnModel.initEstimatedModel(option))
                return false;
        }
        
        return true;
    }
    
    public void estimate(){
        System.out.println("Sampling " + trnModel.niters + " iteration!");
        
        int lastIter = trnModel.liter;
        for (trnModel.liter = lastIter + 1; trnModel.liter < trnModel.niters + lastIter; trnModel.liter++){
            System.out.println("Iteration " + trnModel.liter + " ...");
            
            // for all z_i
            for (int m = 0; m < trnModel.M; m++){               
                for (int n = 0; n < trnModel.data.docs[m].length; n++){
                    // z_i = z[m][n]
                    // sample from p(z_i|z_-i, w)
                    int topic = sampling(m, n);
                    trnModel.z[m].set(n, topic);
                }// end for each word
            }// end for each document
            
            if (option.savestep > 0){
                if (trnModel.liter % option.savestep == 0){
                    System.out.println("Saving the model at iteration " + trnModel.liter + " ...");
                    computeTheta();
                    computePhi();
                    trnModel.saveModel("model-" + Conversion.ZeroPad(trnModel.liter, 5));
                }
            }
        }// end iterations      
        
        System.out.println("Gibbs sampling completed!\n");
        System.out.println("Saving the final model!\n");
        computeTheta();
        computePhi();
        trnModel.liter--;
        trnModel.saveModel("model-final");
    }
    
    /**
     * Do sampling
     * @param m document number
     * @param n word number
     * @return topic id
     */
    public int sampling(int m, int n){
        // remove z_i from the count variable
        int topic = trnModel.z[m].get(n);
        int w = trnModel.data.docs[m].words[n];
        
        trnModel.nw[w][topic] -= 1;
        trnModel.nd[m][topic] -= 1;
        trnModel.nwsum[topic] -= 1;
        trnModel.ndsum[m] -= 1;
        
        double Vbeta = trnModel.V * trnModel.beta;
        double Kalpha = trnModel.K * trnModel.alpha;
        
        //do multinominal sampling via cumulative method
        for (int k = 0; k < trnModel.K; k++){
            trnModel.p[k] = (trnModel.nw[w][k] + trnModel.beta)/(trnModel.nwsum[k] + Vbeta) *
                    (trnModel.nd[m][k] + trnModel.alpha)/(trnModel.ndsum[m] + Kalpha);
        }
        
        // cumulate multinomial parameters
        for (int k = 1; k < trnModel.K; k++){
            trnModel.p[k] += trnModel.p[k - 1];
        }
        
        // scaled sample because of unnormalized p[]
        double u = Math.random() * trnModel.p[trnModel.K - 1];              // 这一段没懂
        
        for (topic = 0; topic < trnModel.K; topic++){
            if (trnModel.p[topic] > u) //sample topic w.r.t distribution p
                break;
        }
        
        // add newly estimated z_i to count variables
        
        trnModel.nw[w][topic] += 1;
        trnModel.nd[m][topic] += 1;
        trnModel.nwsum[topic] += 1;
        trnModel.ndsum[m] += 1;
        return topic;
    }
    
    public void computeTheta(){
        for (int m = 0; m < trnModel.M; m++){
            for (int k = 0; k < trnModel.K; k++){
                trnModel.theta[m][k] = (trnModel.nd[m][k] + trnModel.alpha) / (trnModel.ndsum[m] + trnModel.K * trnModel.alpha);
            }
        }
    }
    
    public void computePhi(){
        for (int k = 0; k < trnModel.K; k++){
            for (int w = 0; w < trnModel.V; w++){
                trnModel.phi[k][w] = (trnModel.nw[w][k] + trnModel.beta) / (trnModel.nwsum[k] + trnModel.V * trnModel.beta);
            }
        }
    }
}

Sampling()部分里面,以下代码没懂。每个word所属的topic初始化时是随机分配的,中间迭代的时候,为什么还是随机的?
p[k]在这是所有topic分布之和,然后随机一个数乘以这个和,得到u。这里u可以理解成word可以取到topic的范围。
然后返回第一个比u大的p[k]的下标k,这里k代表第k个topic,还是前k个topics?
最终要求的不是word只对应某个topic,而是word下的topic分布,和topic下的分布,下一遍看代码要参考分布理解这一段。

// cumulate multinomial parameters
        for (int k = 1; k < trnModel.K; k++){
            trnModel.p[k] += trnModel.p[k - 1];
        }
        
        // scaled sample because of unnormalized p[]
        double u = Math.random() * trnModel.p[trnModel.K - 1];              // 这一段没懂
        
        for (topic = 0; topic < trnModel.K; topic++){
            if (trnModel.p[topic] > u) //sample topic w.r.t distribution p
                break;
        }

3.Inference: 推测过程


package jgibblda;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.StringTokenizer;
import java.util.Vector;

public class Inferencer {   
    // Train model
    public Model trnModel;
    public Dictionary globalDict;
    private LDACmdOption option;
    
    private Model newModel;
    public int niters = 100;
    
    //-----------------------------------------------------
    // Init method
    //-----------------------------------------------------
    public boolean init(LDACmdOption option){
        this.option = option;
        trnModel = new Model();
        
        if (!trnModel.initEstimatedModel(option))
            return false;       
        
        globalDict = trnModel.data.localDict;
        computeTrnTheta();
        computeTrnPhi();
        
        return true;
    }
    
    //inference new model ~ getting data from a specified dataset
    public Model inference( LDADataset newData){
        System.out.println("init new model");
        Model newModel = new Model();       
        
        newModel.initNewModel(option, newData, trnModel);       
        this.newModel = newModel;       
        
        System.out.println("Sampling " + niters + " iteration for inference!");     
        for (newModel.liter = 1; newModel.liter <= niters; newModel.liter++){
            //System.out.println("Iteration " + newModel.liter + " ...");
            
            // for all newz_i
            for (int m = 0; m < newModel.M; ++m){
                for (int n = 0; n < newModel.data.docs[m].length; n++){
                    // (newz_i = newz[m][n]
                    // sample from p(z_i|z_-1,w)
                    int topic = infSampling(m, n);
                    newModel.z[m].set(n, topic);
                }
            }//end foreach new doc
            
        }// end iterations
        
        System.out.println("Gibbs sampling for inference completed!");
        
        computeNewTheta();
        computeNewPhi();
        newModel.liter--;
        return this.newModel;
    }
    
    public Model inference(String [] strs){
        //System.out.println("inference");
        Model newModel = new Model();
        
        //System.out.println("read dataset");
        LDADataset dataset = LDADataset.readDataSet(strs, globalDict);
        
        return inference(dataset);
    }
    
    //inference new model ~ getting dataset from file specified in option
    public Model inference(){   
        //System.out.println("inference");
        
        newModel = new Model();
        if (!newModel.initNewModel(option, trnModel)) return null;
        
        System.out.println("Sampling " + niters + " iteration for inference!");
        
        for (newModel.liter = 1; newModel.liter <= niters; newModel.liter++){
            //System.out.println("Iteration " + newModel.liter + " ...");
            
            // for all newz_i
            for (int m = 0; m < newModel.M; ++m){
                for (int n = 0; n < newModel.data.docs[m].length; n++){
                    // (newz_i = newz[m][n]
                    // sample from p(z_i|z_-1,w)
                    int topic = infSampling(m, n);
                    newModel.z[m].set(n, topic);
                }
            }//end foreach new doc
            
        }// end iterations
        
        System.out.println("Gibbs sampling for inference completed!");      
        System.out.println("Saving the inference outputs!");
        
        computeNewTheta();
        computeNewPhi();
        newModel.liter--;
        newModel.saveModel(newModel.dfile + "." + newModel.modelName);      
        
        return newModel;
    }
    
    /**
     * do sampling for inference
     * m: document number
     * n: word number?
     */
    protected int infSampling(int m, int n){
        // remove z_i from the count variables
        int topic = newModel.z[m].get(n);
        int _w = newModel.data.docs[m].words[n];
        int w = newModel.data.lid2gid.get(_w);
        newModel.nw[_w][topic] -= 1;
        newModel.nd[m][topic] -= 1;
        newModel.nwsum[topic] -= 1;
        newModel.ndsum[m] -= 1;
        
        double Vbeta = trnModel.V * newModel.beta;
        double Kalpha = trnModel.K * newModel.alpha;
        
        // do multinomial sampling via cummulative method       
        for (int k = 0; k < newModel.K; k++){           
            newModel.p[k] = (trnModel.nw[w][k] + newModel.nw[_w][k] + newModel.beta)/(trnModel.nwsum[k] +  newModel.nwsum[k] + Vbeta) *
                    (newModel.nd[m][k] + newModel.alpha)/(newModel.ndsum[m] + Kalpha);
        }
        
        // cummulate multinomial parameters
        for (int k = 1; k < newModel.K; k++){
            newModel.p[k] += newModel.p[k - 1];
        }
        
        // scaled sample because of unnormalized p[]
        double u = Math.random() * newModel.p[newModel.K - 1];     
        
        for (topic = 0; topic < newModel.K; topic++){
            if (newModel.p[topic] > u)
                break;
        }
        
        // add newly estimated z_i to count variables
        newModel.nw[_w][topic] += 1;
        newModel.nd[m][topic] += 1;
        newModel.nwsum[topic] += 1;
        newModel.ndsum[m] += 1;
        
        return topic;
    }
    
    protected void computeNewTheta(){
        for (int m = 0; m < newModel.M; m++){
            for (int k = 0; k < newModel.K; k++){
                newModel.theta[m][k] = (newModel.nd[m][k] + newModel.alpha) / (newModel.ndsum[m] + newModel.K * newModel.alpha);
            }//end foreach topic
        }//end foreach new document
    }
    
    protected void computeNewPhi(){
        for (int k = 0; k < newModel.K; k++){
            for (int _w = 0; _w < newModel.V; _w++){
                Integer id = newModel.data.lid2gid.get(_w);
                
                if (id != null){
                    newModel.phi[k][_w] = (trnModel.nw[id][k] + newModel.nw[_w][k] + newModel.beta) / (newModel.nwsum[k] + newModel.nwsum[k] + trnModel.V * newModel.beta);
                }
            }//end foreach word
        }// end foreach topic
    }
    
    protected void computeTrnTheta(){
        for (int m = 0; m < trnModel.M; m++){
            for (int k = 0; k < trnModel.K; k++){
                trnModel.theta[m][k] = (trnModel.nd[m][k] + trnModel.alpha) / (trnModel.ndsum[m] + trnModel.K * trnModel.alpha);
            }
        }
    }
    
    protected void computeTrnPhi(){
        for (int k = 0; k < trnModel.K; k++){
            for (int w = 0; w < trnModel.V; w++){
                trnModel.phi[k][w] = (trnModel.nw[w][k] + trnModel.beta) / (trnModel.nwsum[k] + trnModel.V * trnModel.beta);
            }
        }
    }
}

4.数据可视化和输出
完整代码可参考原版JGibblda

posted on 2015-04-23 21:16 cynorr 阅读( ...) 评论( ...) 编辑 收藏

转载于:https://www.cnblogs.com/cyno/p/4451804.html

在该作者(http://blog.csdn.net/yangliuy/article/details/8457329)的基础上添加中文分词,实现中文主题发现。相关的文档请到原版作者查阅。谢谢! 也许是待分析的语料太少,效果好像不是很好。 目前的语料输出结果如下: topic 0 : 等 0.010036719031631947 这样 0.010036719031631947 但 0.010036719031631947 下 0.007588739041239023 很难 0.007588739041239023 一个 0.007588739041239023 于 0.007588739041239023 亿元 0.0051407585851848125 目前 0.0051407585851848125 带动 0.0051407585851848125 上 0.0051407585851848125 提出 0.0051407585851848125 地 0.0051407585851848125 做 0.0051407585851848125 技术 0.0051407585851848125 水平 0.0051407585851848125 不 0.0051407585851848125 作 0.0051407585851848125 其实 0.0051407585851848125 市场 0.0051407585851848125 topic 1 : 在 0.02684444561600685 和 0.023288888856768608 对 0.012622222304344177 进行 0.010844443924725056 为 0.009066666476428509 与 0.009066666476428509 选择 0.009066666476428509 还是 0.009066666476428509 其中 0.0072888885624706745 主要 0.0072888885624706745 而 0.0072888885624706745 只有 0.0072888885624706745 看 0.0072888885624706745 遇到 0.0072888885624706745 3 0.005511111114174128 把 0.005511111114174128 也 0.005511111114174128 注意 0.005511111114174128 时间 0.005511111114174128 一种 0.005511111114174128 topic 2 : 英语 0.012685983441770077 考生 0.012685983441770077 可以 0.011119811795651913 词汇 0.009553641080856323 句子 0.009553641080856323 时 0.007987470366060734 就 0.007987470366060734 考试 0.007987470366060734 阅读 0.007987470366060734 写作 0.007987470366060734 上 0.006421299651265144 才能 0.006421299651265144 很多 0.006421299651265144 理解 0.006421299651265144 一些 0.006421299651265144 复习 0.006421299651265144 基础 0.006421299651265144 翻译 0.006421299651265144 大家 0.006421299651265144 根据 0.006421299651265144 topic 3 : 等 0.01035533007234335 公司 0.008324872702360153 网上 0.008324872702360153 法院 0.008324872702360153 和 0.0062944162636995316 迪 0.0062944162636995316 志 0.0062944162636995316 经营 0.0062944162636995316 易趣网 0.0062944162636995316 进 0.0062944162636995316 在 0.004263959359377623 该 0.004263959359377623 其 0.004263959359377623 拥有 0.004263959359377623 5 0.004263959359377623 记者 0.004263959359377623 巨头 0.004263959359377623 直接 0.004263959359377623 研究所 0.004263959359377623 文渊阁 0.004263959359377623 topic 4 : 来 0.010161090642213821 中国 0.010161090642213821 之后 0.007682775612920523 主要 0.007682775612920523 2005年 0.005204460583627224 生产 0.005204460583627224 发展 0.005204460583627224 消费 0.005204460583627224 企业 0.005204460583627224 能 0.005204460583627224 这是 0.005204460583627224 还得 0.005204460583627224 工业 0.005204460583627224 百强 0.005204460583627224 发布 0.005204460583627224 各项 0.005204460583627224 药 0.005204460583627224 会上 0.005204460583627224 汽车 0.002726146252825856 专用汽车 0.002726146252825856 topic 5 : 表示 0.005761316511780024 信息 0.005761316511780024 人们 0.005761316511780024 认为 0.005761316511780024 接受 0.005761316511780024 时 0.005761316511780024 人 0.005761316511780024 没有 0.005761316511780024 最高 0.005761316511780024 过热 0.0030178327579051256 余 0.0030178327579051256 亩 0.0030178327579051256 工程 0.0030178327579051256 系列 0.0030178327579051256 行业 0.0030178327579051256 必须有 0.0030178327579051256 空间 0.0030178327579051256 则 0.0030178327579051256 二次 0.0030178327579051256 专家 0.0030178327579051256 topic 6 : 实力 0.008062418550252914 已经 0.008062418550252914 不同 0.008062418550252914 资金 0.005461638327687979 大量 0.005461638327687979 比 0.005461638327687979 成为 0.005461638327687979 质量 0.005461638327687979 略有 0.005461638327687979 相当 0.005461638327687979 成功 0.005461638327687979 高度 0.005461638327687979 盘 0.005461638327687979 来看 0.005461638327687979 看到 0.005461638327687979 数据 0.005461638327687979 大 0.005461638327687979 越来越多 0.005461638327687979 楼 0.005461638327687979 投资 0.0028608583379536867 topic 7 : 以 0.009867629036307335 nbsp 0.0074608903378248215 曼 0.0074608903378248215 桢 0.0074608903378248215 7 0.005054151173681021 2 0.005054151173681021 其 0.005054151173681021 300 0.005054151173681021 就是 0.005054151173681021 他 0.005054151173681021 又 0.005054151173681021 半生 0.005054151173681021 缘 0.005054151173681021 香港 0.005054151173681021 她也 0.005054151173681021 世 0.005054151173681021 璐 0.005054151173681021 祝 0.005054151173681021 鸿 0.005054151173681021 文 0.005054151173681021 topic 8 : 在 0.016857441514730453 小 0.012695109471678734 这 0.010613943450152874 袁 0.010613943450152874 电话 0.010613943450152874 上海 0.008532778359949589 东莞 0.008532778359949589 总部 0.006451612804085016 没有 0.006451612804085016 他 0.006451612804085016 大学生 0.006451612804085016 设立 0.006451612804085016 随后 0.006451612804085016 才 0.006451612804085016 广东 0.004370447248220444 不少 0.004370447248220444 依然 0.004370447248220444 回 0.004370447248220444 该公司 0.004370447248220444 15日 0.004370447248220444 topic 9 : 旅游 0.016091953963041306 游客 0.01432360801845789 解析 0.009018567390739918 五一 0.009018567390739918 接待 0.009018567390739918 增长 0.009018567390739918 再次 0.0072502209804952145 黄金周 0.0072502209804952145 南京 0.0072502209804952145 里 0.0072502209804952145 人次 0.0072502209804952145 景点 0.0072502209804952145 也 0.005481874104589224 以上 0.005481874104589224 已经 0.005481874104589224 数据 0.005481874104589224 今年 0.005481874104589224 同期 0.005481874104589224 周边 0.005481874104589224 景区 0.005481874104589224
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值