A quick aside: reading the LDA prediction code

Code: https://github.com/madlib/madlib/blob/master/src/modules/lda/lda.cpp

It contains the LDA prediction functionality.

/**
 * @brief This function samples a new topic for a word in a document based on
 * the topic counts computed on the rest of the corpus. This is the core
 * function in the Gibbs sampling inference algorithm for LDA.
 * @param topic_num     The number of topics
 * @param topic         The current topic assignment to a word
 * @param count_d_z     The document topic counts
 * @param count_w_z     The word topic counts
 * @param count_z       The corpus topic counts
 * @param alpha         The Dirichlet parameter for the per-doc topic
 *                      multinomial
 * @param beta          The Dirichlet parameter for the per-topic word
 *                      multinomial
 * @return retopic      The new topic assignment to the word
 * @note The topic ranges from 0 to topic_num - 1.
 *
 * @note For the sake of performance, this function will not check the validity
 * of parameters. The caller will ensure that the three pointers all have non-null
 * values and the lengths are the actual lengths of the arrays. And this
 * function is local to this file only, so this function cannot be maliciously
 * called by intruders.
 **/
static int32_t __lda_gibbs_sample(
    int32_t topic_num, int32_t topic, const int32_t * count_d_z, const int32_t * count_w_z,
    const int64_t * count_z, double alpha, double beta)
{
    /* The cumulative probability distribution of the topics */
    double * topic_prs = new double[topic_num];
    if(!topic_prs)
        throw std::runtime_error("out of memory");

    /* Calculate topic (unnormalised) probabilities */
    double total_unpr = 0;
    for (int32_t i = 0; i < topic_num; i++) {
        int32_t nwz = count_w_z[i];
        int32_t ndz = count_d_z[i];
        int64_t nz = count_z[i];

        /* Adjust the counts to exclude current word's contribution */
        if (i == topic) {
            nwz--;
            ndz--;
            nz--;
        }

        /* Compute the probability */
        // Note that ndz, nwz, nz are non-negative, and topic_num, alpha, and
        // beta are positive, so the division by zero will not occur here.
        double unpr =
            (ndz + alpha) * (static_cast<double>(nwz) + beta)
                / (static_cast<double>(nz) + topic_num * beta);
        total_unpr += unpr;
        topic_prs[i] = total_unpr;
    }

    /* Normalise the probabilities */
    // Note that the division by zero will not occur here, so no need to check
    // whether total_unpr is zero
    for (int32_t i = 0; i < topic_num; i++)
        topic_prs[i] /= total_unpr;

    /* Draw a topic at random */
    double r = drand48();
    int32_t retopic = 0;
    while (true) {
        if (retopic == topic_num - 1 || r < topic_prs[retopic])
            break;
        retopic++;
    }

    delete[] topic_prs;
    return retopic;
}
/**
 * @brief This function learns the topics of words in a document and is the
 * main step of a Gibbs sampling iteration. The word topic counts and
 * corpus topic counts are passed to this function in the first call and
 * then transferred to the subsequent calls through args.mSysInfo->user_fctx for
 * efficiency.
 * @param args[0]   The unique words in the documents
 * @param args[1]   The count of each unique word
 * @param args[2]   The topic counts and topic assignments in the document
 * @param args[3]   The model (word topic counts and corpus topic
 *                  counts)
 * @param args[4]   The Dirichlet parameter for per-document topic
 *                  multinomial, i.e. alpha
 * @param args[5]   The Dirichlet parameter for per-topic word
 *                  multinomial, i.e. beta
 * @param args[6]   The size of vocabulary
 * @param args[7]   The number of topics
 * @param args[8]   The number of iterations (=1:training, >1:prediction)
 * @return          The updated topic counts and topic assignments for
 *                  the document
 **/
AnyType lda_gibbs_sample::run(AnyType & args)
{
    ArrayHandle<int32_t> words = args[0].getAs<ArrayHandle<int32_t> >();
    ArrayHandle<int32_t> counts = args[1].getAs<ArrayHandle<int32_t> >();
    MutableArrayHandle<int32_t> doc_topic = args[2].getAs<MutableArrayHandle<int32_t> >();
    double alpha = args[4].getAs<double>();
    double beta = args[5].getAs<double>();
    int32_t voc_size = args[6].getAs<int32_t>();
    int32_t topic_num = args[7].getAs<int32_t>();
    int32_t iter_num = args[8].getAs<int32_t>();
    size_t model64_size = static_cast<size_t>(voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t);

    if(alpha <= 0)
        throw std::invalid_argument("invalid argument - alpha");
    if(beta <= 0)
        throw std::invalid_argument("invalid argument - beta");
    if(voc_size <= 0)
        throw std::invalid_argument(
            "invalid argument - voc_size");
    if(topic_num <= 0)
        throw std::invalid_argument(
            "invalid argument - topic_num");
    if(iter_num <= 0)
        throw std::invalid_argument(
            "invalid argument - iter_num");

    if(words.size() != counts.size())
        throw std::invalid_argument(
            "dimensions mismatch: words.size() != counts.size()");
    if(__min(words) < 0 || __max(words) >= voc_size)
        throw std::invalid_argument(
            "invalid values in words");
    if(__min(counts) <= 0)
        throw std::invalid_argument(
            "invalid values in counts");

    int32_t word_count = __sum(counts);
    if(doc_topic.size() != (size_t)(word_count + topic_num))
        throw std::invalid_argument(
            "invalid dimension - doc_topic.size() != word_count + topic_num");
    if(__min(doc_topic, 0, topic_num) < 0)
        throw std::invalid_argument("invalid values in topic_count");
    if(
        __min(doc_topic, topic_num, word_count) < 0 ||
        __max(doc_topic, topic_num, word_count) >= topic_num)
        throw std::invalid_argument( "invalid values in topic_assignment");

    if (!args.getUserFuncContext()) {
        ArrayHandle<int64_t> model64 = args[3].getAs<ArrayHandle<int64_t> >();
        if (model64.size() != model64_size) {
            std::stringstream ss;
            ss << "invalid dimension: model64.size() = " << model64.size();
            throw std::invalid_argument(ss.str());
        }
        if (__min(model64) < 0) {
            throw std::invalid_argument("invalid topic counts in model");
        }

        int32_t *context =
            static_cast<int32_t *>(
                MemoryContextAllocZero(
                    args.getCacheMemoryContext(),
                    model64.size() * sizeof(int64_t)
                        + topic_num * sizeof(int64_t)));
        memcpy(context, model64.ptr(), model64.size() * sizeof(int64_t));
        int32_t *model = context;

        int64_t *running_topic_counts = reinterpret_cast<int64_t *>(
                context + model64_size * sizeof(int64_t) / sizeof(int32_t));
        for (int i = 0; i < voc_size; i ++) {
            for (int j = 0; j < topic_num; j ++) {
                running_topic_counts[j] += model[i * (topic_num + 1) + j];
            }
        }

        args.setUserFuncContext(context);
    }

    int32_t *context = static_cast<int32_t *>(args.getUserFuncContext());
    if (context == NULL) {
        throw std::runtime_error("args.mSysInfo->user_fctx is null");
    }
    int32_t *model = context;
    int64_t *running_topic_counts = reinterpret_cast<int64_t *>(
            context + model64_size * sizeof(int64_t) / sizeof(int32_t));

    int32_t unique_word_count = static_cast<int32_t>(words.size());
    for(int32_t it = 0; it < iter_num; it++){
        int32_t word_index = topic_num;
        for(int32_t i = 0; i < unique_word_count; i++) {
            int32_t wordid = words[i];
            for(int32_t j = 0; j < counts[i]; j++){
                int32_t topic = doc_topic[word_index];
                int32_t retopic = __lda_gibbs_sample(
                    topic_num, topic, doc_topic.ptr(),
                    model + wordid * (topic_num + 1),
                    running_topic_counts, alpha, beta);
                doc_topic[word_index] = retopic;
                doc_topic[topic]--;
                doc_topic[retopic]++;

                if(iter_num == 1) {
                    if (model[wordid * (topic_num + 1) + retopic] <= 2e9) {
                        running_topic_counts[topic] --;
                        running_topic_counts[retopic] ++;
                        model[wordid * (topic_num + 1) + topic]--;
                        model[wordid * (topic_num + 1) + retopic]++;
                    } else {
                        model[wordid * (topic_num + 1) + topic_num] = 1;
                    }
                }
                word_index++;
            }
        }
    }

    return doc_topic;
}

The key steps are shown above. The core function is __lda_gibbs_sample, and run is the main entry point.
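Before the detailed analysis, it helps to pin down the model buffer layout implied by the arithmetic in run: voc_size rows of topic_num per-word topic counts plus one extra int32 slot per row (set to 1 as an overflow flag once a count would exceed 2e9), with one trailing int32, all shipped as an int64 array of model64_size elements. A minimal standalone sketch of that indexing (toy sizes and my own variable names, not MADlib code):

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Toy sizes, made up for illustration; the real values come from args[6]/args[7].
    const int32_t voc_size = 4, topic_num = 3;

    // As in run above: voc_size rows of (topic_num + 1) int32 slots plus one
    // trailing int32, shipped to the UDF as an int64 array of model64_size elements.
    size_t model32_len = static_cast<size_t>(voc_size) * (topic_num + 1) + 1;
    size_t model64_size = model32_len * sizeof(int32_t) / sizeof(int64_t);
    std::vector<int32_t> model(model32_len, 0);

    // Per-word topic count, using exactly the index expression from run:
    int32_t wordid = 2, topic = 1;
    model[wordid * (topic_num + 1) + topic]++;

    // The extra slot at per-row offset topic_num is the overflow flag that run
    // sets to 1 once a count would exceed 2e9:
    int32_t overflow_flag = model[wordid * (topic_num + 1) + topic_num];

    std::printf("model64_size=%zu count=%d flag=%d\n", model64_size,
                model[wordid * (topic_num + 1) + topic], overflow_flag);
    return 0;
}

With these toy sizes, model32_len is 17 int32s, so model64_size comes out to 8 by the same integer division run performs.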

1. Analysis of the Gibbs sampling draw:

    /* Normalise the probabilities */
    // Note that the division by zero will not occur here, so no need to check
    // whether total_unpr is zero
    // Comment: every topic gets a value here, and it may look odd that the
    // values grow with the topic index, the last one always being exactly 1.
    // That is because the loop above accumulates:
    //     total_unpr += unpr;
    //     topic_prs[i] = total_unpr;
    // so topic_prs holds the *cumulative* distribution, not the individual
    // topic probabilities.
    for (int32_t i = 0; i < topic_num; i++)
        topic_prs[i] /= total_unpr;

    /* Draw a topic at random */
    // This resolves the oddity above: draw a uniform r in [0, 1) and pick the
    // first topic whose cumulative probability exceeds r. This is standard
    // inverse-CDF sampling, so each topic is selected with exactly its own
    // normalised probability.
    double r = drand48();
    int32_t retopic = 0;
    while (true) {
        if (retopic == topic_num - 1 || r < topic_prs[retopic])
            break;
        retopic++;
    }


2. Key analysis of run:

Notice that the inner loop does doc_topic[word_index] = retopic; and also doc_topic[topic]--; followed by doc_topic[retopic]++;

So doc_topic stores both the per-topic counts and the topic assignment of each word occurrence.

Together with the earlier size check

if(doc_topic.size() != (size_t)(word_count + topic_num))

we can see the layout: the first topic_num entries hold the document's per-topic counts, and the following word_count entries hold the topic assignment of each word occurrence (word_index starts at topic_num), as illustrated below.
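A minimal sketch (made-up numbers, not MADlib code) of that layout and of the bookkeeping run performs when a word occurrence is reassigned:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // Toy layout mirroring the check in run:
    // doc_topic.size() == topic_num + word_count.
    const int32_t topic_num = 3;

    // First topic_num entries: per-topic counts for this document.
    // Remaining word_count entries: the topic assigned to each occurrence.
    std::vector<int32_t> doc_topic = {
        2, 1, 1,        // counts: 2 words in topic 0, 1 in topic 1, 1 in topic 2
        0, 0, 1, 2      // assignments of the 4 word occurrences
    };

    // Reassigning one occurrence from `topic` to `retopic`, as run does:
    int32_t word_index = topic_num + 2;          // third occurrence
    int32_t topic = doc_topic[word_index];       // old topic (here 1)
    int32_t retopic = 0;                         // pretend the sampler chose 0
    doc_topic[word_index] = retopic;
    doc_topic[topic]--;                          // counts live in the prefix
    doc_topic[retopic]++;

    for (size_t i = 0; i < doc_topic.size(); i++)
        std::printf("%d ", doc_topic[i]);        // prints: 3 0 1 0 0 0 2
    std::printf("\n");
    return 0;
}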
