代码: https://github.com/madlib/madlib/blob/master/src/modules/lda/lda.cpp
里面有 LDA 的预测功能。关键代码如下:
/**
* @brief This function samples a new topic for a word in a document based on
* the topic counts computed on the rest of the corpus. This is the core
* function in the Gibbs sampling inference algorithm for LDA.
* @param topic_num The number of topics
* @param topic The current topic assignment to a word
* @param count_d_z The document topic counts
* @param count_w_z The word topic counts
* @param count_z The corpus topic counts
* @param alpha The Dirichlet parameter for the per-doc topic
* multinomial
* @param beta The Dirichlet parameter for the per-topic word
* multinomial
* @return retopic The new topic assignment to the word
* @note The topic ranges from 0 to topic_num - 1.
*
* @note For the sake of performance, this function will not check the validity
* of parameters. The caller will ensure that the three pointers all have non-null
* values and the lengths are the actual lengths of the arrays. And this
* function is local to this file only, so this function cannot be maliciously
* called by intruders.
**/
static int32_t __lda_gibbs_sample(
    int32_t topic_num, int32_t topic, const int32_t * count_d_z, const int32_t * count_w_z,
    const int64_t * count_z, double alpha, double beta)
{
    /* The cumulative (running-sum) probability distribution of the topics.
     * std::vector replaces the original raw new[]/delete[]: operator new[]
     * throws std::bad_alloc on failure and never returns null, so the old
     * null-check-plus-throw was dead code, and RAII guarantees the buffer is
     * released on every exit path (the raw version leaked on exceptions). */
    std::vector<double> topic_prs(static_cast<size_t>(topic_num));

    /* Calculate topic (unnormalised) probabilities */
    double total_unpr = 0;
    for (int32_t i = 0; i < topic_num; i++) {
        int32_t nwz = count_w_z[i];
        int32_t ndz = count_d_z[i];
        int64_t nz = count_z[i];
        /* Adjust the counts to exclude the current word's own contribution */
        if (i == topic) {
            nwz--;
            ndz--;
            nz--;
        }
        /* Compute the unnormalised probability.
         * ndz, nwz, nz are non-negative and topic_num, alpha, beta are
         * positive, so the denominator is strictly positive — no division
         * by zero can occur here. */
        double unpr =
            (ndz + alpha) * (static_cast<double>(nwz) + beta)
            / (static_cast<double>(nz) + topic_num * beta);
        total_unpr += unpr;
        /* Store the running sum — topic_prs is a cumulative distribution */
        topic_prs[i] = total_unpr;
    }

    /* Normalise the cumulative probabilities so the last entry is 1.
     * total_unpr > 0 for the same reason as above, so no check is needed. */
    for (int32_t i = 0; i < topic_num; i++)
        topic_prs[i] /= total_unpr;

    /* Draw a topic at random: r ~ U[0, 1), pick the first topic whose
     * cumulative probability exceeds r (inverse-CDF sampling). */
    double r = drand48();
    int32_t retopic = 0;
    while (true) {
        if (retopic == topic_num - 1 || r < topic_prs[retopic])
            break;
        retopic++;
    }
    return retopic;
}
/**
* @brief This function learns the topics of words in a document and is the
* main step of a Gibbs sampling iteration. The word topic counts and
* corpus topic counts are passed to this function in the first call and
* then transfered to the rest calls through args.mSysInfo->user_fctx for
* efficiency.
* @param args[0] The unique words in the documents
* @param args[1] The counts of each unique words
* @param args[2] The topic counts and topic assignments in the document
* @param args[3] The model (word topic counts and corpus topic
* counts)
* @param args[4] The Dirichlet parameter for per-document topic
* multinomial, i.e. alpha
* @param args[5] The Dirichlet parameter for per-topic word
* multinomial, i.e. beta
* @param args[6] The size of vocabulary
* @param args[7] The number of topics
* @param args[8] The number of iterations (=1:training, >1:prediction)
* @return The updated topic counts and topic assignments for
* the document
**/
AnyType lda_gibbs_sample::run(AnyType & args)
{
// Unpack the call arguments (see the @param list above).
ArrayHandle<int32_t> words = args[0].getAs<ArrayHandle<int32_t> >();
ArrayHandle<int32_t> counts = args[1].getAs<ArrayHandle<int32_t> >();
MutableArrayHandle<int32_t> doc_topic = args[2].getAs<MutableArrayHandle<int32_t> >();
double alpha = args[4].getAs<double>();
double beta = args[5].getAs<double>();
int32_t voc_size = args[6].getAs<int32_t>();
int32_t topic_num = args[7].getAs<int32_t>();
int32_t iter_num = args[8].getAs<int32_t>();
// Size, in int64 units, of the model array: per word, topic_num counts plus
// one flag slot (topic_num + 1 int32s), plus one trailing int32, halved
// because two int32s pack into one int64.
size_t model64_size = static_cast<size_t>(voc_size * (topic_num + 1) + 1) * sizeof(int32_t) / sizeof(int64_t);
// Validate scalar parameters before touching any array data.
if(alpha <= 0)
throw std::invalid_argument("invalid argument - alpha");
if(beta <= 0)
throw std::invalid_argument("invalid argument - beta");
if(voc_size <= 0)
throw std::invalid_argument(
"invalid argument - voc_size");
if(topic_num <= 0)
throw std::invalid_argument(
"invalid argument - topic_num");
if(iter_num <= 0)
throw std::invalid_argument(
"invalid argument - iter_num");
// words and counts are parallel arrays: counts[i] occurrences of words[i].
if(words.size() != counts.size())
throw std::invalid_argument(
"dimensions mismatch: words.size() != counts.size()");
// Every word id must be a valid vocabulary index.
if(__min(words) < 0 || __max(words) >= voc_size)
throw std::invalid_argument(
"invalid values in words");
if(__min(counts) <= 0)
throw std::invalid_argument(
"invalid values in counts");
int32_t word_count = __sum(counts);
// doc_topic layout: first topic_num entries are the per-document topic
// counts; the remaining word_count entries are per-word topic assignments.
if(doc_topic.size() != (size_t)(word_count + topic_num))
throw std::invalid_argument(
"invalid dimension - doc_topic.size() != word_count + topic_num");
if(__min(doc_topic, 0, topic_num) < 0)
throw std::invalid_argument("invalid values in topic_count");
if(
__min(doc_topic, topic_num, word_count) < 0 ||
__max(doc_topic, topic_num, word_count) >= topic_num)
throw std::invalid_argument( "invalid values in topic_assignment");
// First call in this backend process: copy the model into a cached memory
// context and precompute the per-topic corpus counts, so subsequent calls
// can reuse them via the user function context instead of re-reading args[3].
if (!args.getUserFuncContext()) {
ArrayHandle<int64_t> model64 = args[3].getAs<ArrayHandle<int64_t> >();
if (model64.size() != model64_size) {
std::stringstream ss;
ss << "invalid dimension: model64.size() = " << model64.size();
throw std::invalid_argument(ss.str());
}
if (__min(model64) < 0) {
throw std::invalid_argument("invalid topic counts in model");
}
// Zero-initialised buffer holding the int32 model followed by topic_num
// int64 running topic counts, allocated in the cache memory context so it
// survives across calls.
int32_t *context =
static_cast<int32_t *>(
MemoryContextAllocZero(
args.getCacheMemoryContext(),
model64.size() * sizeof(int64_t)
+ topic_num * sizeof(int64_t)));
memcpy(context, model64.ptr(), model64.size() * sizeof(int64_t));
int32_t *model = context;
// running_topic_counts starts right after the model region; the offset is
// expressed in int32 units because context is an int32 pointer.
int64_t *running_topic_counts = reinterpret_cast<int64_t *>(
context + model64_size * sizeof(int64_t) / sizeof(int32_t));
// Derive the corpus-wide per-topic counts by summing over the vocabulary.
for (int i = 0; i < voc_size; i ++) {
for (int j = 0; j < topic_num; j ++) {
running_topic_counts[j] += model[i * (topic_num + 1) + j];
}
}
args.setUserFuncContext(context);
}
// Re-fetch the cached context (set either above or by a previous call).
int32_t *context = static_cast<int32_t *>(args.getUserFuncContext());
if (context == NULL) {
throw std::runtime_error("args.mSysInfo->user_fctx is null");
}
int32_t *model = context;
int64_t *running_topic_counts = reinterpret_cast<int64_t *>(
context + model64_size * sizeof(int64_t) / sizeof(int32_t));
int32_t unique_word_count = static_cast<int32_t>(words.size());
// Gibbs sampling sweeps: resample a topic for every word occurrence.
for(int32_t it = 0; it < iter_num; it++){
// Word assignments start right after the topic-count prefix of doc_topic.
int32_t word_index = topic_num;
for(int32_t i = 0; i < unique_word_count; i++) {
int32_t wordid = words[i];
for(int32_t j = 0; j < counts[i]; j++){
int32_t topic = doc_topic[word_index];
int32_t retopic = __lda_gibbs_sample(
topic_num, topic, doc_topic.ptr(),
model + wordid * (topic_num + 1),
running_topic_counts, alpha, beta);
// Record the new assignment and shift the document topic counts.
doc_topic[word_index] = retopic;
doc_topic[topic]--;
doc_topic[retopic]++;
// Only training (iter_num == 1) updates the shared model; prediction
// (iter_num > 1) keeps the model frozen and resamples doc_topic only.
if(iter_num == 1) {
// NOTE(review): 2e9 looks like an int32-overflow guard (~INT32_MAX).
// When it trips, only the per-word flag slot (index topic_num) is set
// and neither the model nor running_topic_counts is updated, leaving
// them out of sync with doc_topic — presumably handled downstream;
// confirm against the caller.
if (model[wordid * (topic_num + 1) + retopic] <= 2e9) {
running_topic_counts[topic] --;
running_topic_counts[retopic] ++;
model[wordid * (topic_num + 1) + topic]--;
model[wordid * (topic_num + 1) + retopic]++;
} else {
model[wordid * (topic_num + 1) + topic_num] = 1;
}
}
word_index++;
}
}
}
// Return the updated topic counts + assignments for this document.
return doc_topic;
}
https://github.com/madlib/madlib/blob/master/src/modules/lda/lda.cpp
关键步骤如上。主要有函数是__lda_gibbs_sample。 run是主入口。
1、gibbs抽样分析:
/* Normalise the probabilities */
// Note that the division by zero will not occure here, so no need to check
// whether total_unpr is zero
// 注释: 这里每个 topic 都有一个概率值,但可能让人疑惑的是:topic_prs 存的是累积和,越往后的值越大,归一化之后最后一个必然为 1。
// 因为: total_unpr += unpr;
//       topic_prs[i] = total_unpr;
for (int32_t i = 0; i < topic_num; i++)
topic_prs[i] /= total_unpr;
/* Draw a topic at random */
// 这里是消除上述问题的关键,会随机得到一个[0,1)的小数r,这里找到第一个大于r的概率topic。 这里保证了一定的随机性。
double r = drand48();
int32_t retopic = 0;
while (true) {
if (retopic == topic_num - 1 || r < topic_prs[retopic])
break;
retopic++;
}
可以看到 doc_topic[word_index] = retopic; 之后又有 doc_topic[topic]--; doc_topic[retopic]++;
说明 doc_topic 既保存了各 topic 的计数,又保存了每个词被分配到的 topic 编号。
结合前面的维度检查 if(doc_topic.size() != (size_t)(word_count + topic_num))
可以看出:前 topic_num 个元素是本文档的 topic 计数,后 word_count 个元素是每个词对应的 topic 分配。