plda源码(十二)
LightLDA
原始 Gibbs Sampling 采样函数如下:
$$p(z_{di}=k \mid rest) \propto \frac{(n^{-di}_{kd}+\alpha_k)(n^{-di}_{kw}+\beta_w)}{n^{-di}_{k}+\overline{\beta}}$$
AliasLDA
$$p(z_{di}=k \mid rest) \propto \frac{n^{-di}_{kd}(n^{-di}_{kw}+\beta_w)}{n^{-di}_{k}+\overline{\beta}} + \frac{\alpha_k(n^{-di}_{kw}+\beta_w)}{n^{-di}_{k}+\overline{\beta}}$$
第二项可以看做“topic-word”桶,与文档无关。这一项可以通过Alias Table和 Metropolis-Hastings(一种蒙特卡洛采样方法) 进行O(1) 时间复杂度采样。Alias Table在上一篇文章有介绍。
LightLDA
$$p(z_{di}=k \mid rest) \propto (n^{-di}_{kd}+\alpha_k) \cdot \frac{n^{-di}_{kw}+\beta_w}{n^{-di}_{k}+\overline{\beta}}$$
$$q(z_{di}=k \mid rest) \propto (n_{kd}+\alpha_k) \cdot \frac{n_{kw}+\beta_w}{n_{k}+\overline{\beta}}$$
第一项为doc-proposal,第二项为word-proposal。
LightLDA 将后验分解为 doc-proposal 与 word-proposal 两个易采样的建议分布,交替使用 Metropolis-Hastings 采样,从状态 s(旧 topic)转移到 t(新 topic)的接受概率为:

$$\min\left\{1,\; \frac{p(t)\,q(t \rightarrow s)}{p(s)\,q(s \rightarrow t)}\right\}$$
doc-proposal
$$q = p_d(k) \propto n_{kd}+\alpha_k$$
接受率
$$\pi_d = \frac{(n^{-di}_{td}+\alpha_t)(n^{-di}_{tw}+\beta_w)(n^{-di}_{s}+\overline{\beta})(n_{sd}+\alpha_s)}{(n^{-di}_{sd}+\alpha_s)(n^{-di}_{sw}+\beta_w)(n^{-di}_{t}+\overline{\beta})(n_{td}+\alpha_t)}$$
int K = model_->num_topics();
double sumPd = document->GetDocumentLength() + Kalpha;
for (...) {
int w = iterator.Word();
int topic = iterator.Topic();
int new_topic;
int old_topic = topic;
{
// Draw a topic from doc-proposal
double u = random->RandDouble() * sumPd;
if (u < document->GetDocumentLength()) {
// draw from doc-topic distribution skipping n
unsigned pos = (unsigned) (u);
new_topic = document->topics().wordtopics(pos);
} else {
// draw uniformly
u -= document->GetDocumentLength();
u = u / alpha_;
new_topic = (unsigned short) (u); // pick_a_number(0,trngdata->docs[m]->length-1); (int)(utils::unif01()*ptrndata->docs[m]->length);
}
if (topic != new_topic) {
//2. Find acceptance probability
int ajustment_old = topic == old_topic? -1 : 0;
int ajustment_new = new_topic == old_topic? -1 : 0;
double temp_old = ComputeProbForK(document, w, topic, ajustment_old);
double temp_new = ComputeProbForK(document, w, new_topic, ajustment_new);
double prop_old = (N_DK(document, topic) + alpha_);
double prop_new = (N_DK(document, new_topic) + alpha_);
double acceptance = (temp_new * prop_old) / (temp_old * prop_new);
//3. Compare against uniform[0,1]
if (random->RandDouble() < acceptance) {
topic = new_topic;
}
}
其中的ComputeProbForK是
double ComputeProbForK(LDADocument* document, int w, int topic,
int ajustment) {
return (N_DK(document, topic) + alpha_ + ajustment)
* (N_WK(w, topic) + beta_ + ajustment)
/ (N_K(topic) + Vbeta + ajustment);
}
word-proposal
$$q = p_w(k) \propto \frac{n_{kw}+\beta_w}{n_{k}+\overline{\beta}}$$
接受率
$$\pi_w = \frac{(n^{-di}_{td}+\alpha_t)(n^{-di}_{tw}+\beta_w)(n^{-di}_{s}+\overline{\beta})(n_{sw}+\beta_w)(n_{t}+\overline{\beta})}{(n^{-di}_{sd}+\alpha_s)(n^{-di}_{sw}+\beta_w)(n^{-di}_{t}+\overline{\beta})(n_{tw}+\beta_w)(n_{s}+\overline{\beta})}$$
{
// Draw a topic from word-proposal
q[w].noSamples++;
if (q[w].noSamples > qtable_construct_frequency) {
GenerateQTable(w);
}
new_topic = q[w].sample(random->RandInt(K), random->RandDouble());
if (topic != new_topic) {
//2. Find acceptance probability
int ajustment_old = topic == old_topic? -1 : 0;
int ajustment_new = new_topic == old_topic? -1 : 0;
double temp_old = ComputeProbForK(document, w, topic, ajustment_old);
double temp_new = ComputeProbForK(document, w, new_topic, ajustment_new);
double acceptance = (temp_new * q[w].w[topic]) / (temp_old * q[w].w[new_topic]);
//3. Compare against uniform[0,1]
if (random->RandDouble() < acceptance) {
topic = new_topic;
}
}
}
其中GenerateQTable如下
void GenerateQTable(unsigned int w) {
int num_topics = model_->num_topics();
q[w].wsum = 0.0;
const TopicDistribution<int32>& word_distribution = model_->GetWordTopicDistribution(w);
const TopicDistribution<int32>& n_k = model_->GetGlobalTopicDistribution();
for (int k = 0; k < num_topics; ++k) {
q[w].w[k] = (word_distribution[k] + beta_) / (n_k[k] + Vbeta);
q[w].wsum += q[w].w[k];
}
q[w].constructTable();
}