困惑度(Perplexity):评价语言模型的指标
1.定义
PPL(Perplexity) 是用在自然语言处理领域(NLP)中,衡量语言模型好坏的指标。它主要是根据每个词来估计一句话出现的概率,并用句子长度作normalize。
- 其本质上就是计算句子的概率,例如对于句子S(词语w的序列):
S = W 1 , W 2 , W 3 , . . . , W k S = W_1,W_2,W_3,...,W_k S=W1,W2,W3,...,Wk
- 它的概率为:
P ( S ) = P ( W 1 , W 2 , W 3 , . . . , W k ) = p ( W 1 ) p ( W 2 ∣ W 1 ) . . . p ( W k ∣ W 1 , W 2 , W 3 , . . . , W k − 1 ) P(S) = P(W_1,W_2,W_3,...,W_k)= p(W_1)p(W_2|W_1)...p(W_k|W_1,W_2,W_3,...,W_{k-1}) P(S)=P(W1,W2,W3,...,Wk)=p(W1)p(W2∣W1)...p(Wk∣W1,W2,W3,...,Wk−1)
困惑度与测试集上的句子概率相关,其基本思想是:给测试集的句子赋予较高概率值的语言模型较好,当语言模型训练完之后,测试集中的句子都是正常的句子,那么训练好的模型就是在测试集上的概率越高越好。
- 通俗点来讲,假设词库里有10个单词,那么对于一个完全没有训练过的模型,其预测一个特定单词的概率就是1/10,概率是均等分的,这时候我们就能得出其困惑度为10,也就是模式是完全糊涂的,没有任何分辨能力。但是当模型能将一个特定单词预测出1/2的概率时,就代表模型能从10个单词中挑选出2个可能对的单词,这时候模型的困惑度就是2,说明模型有了一定的分辨能力。当然,这么简单的求倒数获取困惑度的前提是概率是均等的,如果概率不均等,那么困惑度和预测的倒数就不是相等关系了。
- 当然,最好的就是模型能识别出那个正确的单词,给予100%的概率,这时候模型的困惑度就是1,代表模型没有任何困惑,是完全清楚的,可以正确识别单词,也就是能正确识别一个句子。
2.公式
下面讲一下其基础公式:
P
P
(
W
)
=
P
(
w
1
w
2
w
3
.
.
.
w
N
)
−
1
N
=
1
P
(
w
1
w
2
w
3
.
.
.
w
N
)
N
PP(W)=P(w_1w_2w_3...w_N)^{-\frac{1}{N}}\\ = \sqrt[N]{\frac{1}{P(w_1w_2w_3...w_N)}}
PP(W)=P(w1w2w3...wN)−N1=NP(w1w2w3...wN)1
这里补充一下公式的细节:
-
根号内是句子概率的倒数,所以显然 句子越好(概率大),困惑度越小,也就是模型对句子越不困惑。 这样我们也就理解了这个指标的名字。
-
开N次根号(N为句子长度)意味着几何平均数(把句子概率拆成字符概率的连乘)
-
-
需要平均的原因是,因为每个字符的概率必然小于1,所以越长的句子的概率在连乘的情况下必然越小,所以为了对长短句公平,需要平均一下
-
是几何平均的原因,是因为其的特点是,如果有其中的一个概率是很小的,那么最终的结果就不可能很大,从而要求好的句子的每个字符都要有基本让人满意的概率 [2]
-
- 机器翻译常用指标BLEU也使用了几何平均,还有机器学习常用的F score使用的调和平均数 ,也有类似的效果。
-
当然,这是在数学领域内计算困惑度的公式,在实际的代码层面,用的是另一套公式,需要将上述公式进行转换,下面我就详细来介绍一下:
- 在真实的代码计算中,上述的公式很难计算,但是就是有大佬发现,其实上述的公式可以转化为求交叉熵的公式。而背后的原理是,不管是困惑度,还是交叉熵,其本质上都是在计算信息熵,所以都是在计算模型的混乱程度,因此两者在数学意义的转换就有了理论依据,下面看一下公式转换过程:
P P ( W ) = 2 H ( W ) = 2 − 1 N log 2 P ( w 1 , w 2 , w 3 , . . . , w N ) PP(W) = 2^{H(W)}=2^{-\frac{1}{N}\log_2P(w_1,w_2,w_3,...,w_N)} PP(W)=2H(W)=2−N1log2P(w1,w2,w3,...,wN)
P P ( W ) = 2 − 1 N log 2 P ( w 1 , w 2 , w 3 , . . . , w N ) = ( 2 log 2 P ( w 1 , w 2 , w 3 , . . . , w N ) ) − 1 N = P ( w 1 , w 2 , w 3 , . . . , w N ) − 1 N = 1 P ( w 1 w 2 w 3 . . . w N ) N PP(W)=2^{-\frac{1}{N}\log_2P(w_1,w_2,w_3,...,w_N)}\\ =(2^{\log_2P(w_1,w_2,w_3,...,w_N)})^{-\frac{1}{N}}\\ =P(w_1,w_2,w_3,...,w_N)^{-\frac{1}{N}}\\ =\sqrt[N]{\frac{1}{P(w_1w_2w_3...w_N)}} PP(W)=2−N1log2P(w1,w2,w3,...,wN)=(2log2P(w1,w2,w3,...,wN))−N1=P(w1,w2,w3,...,wN)−N1=NP(w1w2w3...wN)1
- 从上面可以看出,
PP(W)
在本质上就是变成了交叉熵加一个底数的指数函数,所以当我们要求困惑度,就可以直接求交叉熵了。 - 这里还有一个细节,这个底数和log是配套的,在公式中间可以直接消掉,所以底数的大小并不重要,这里选了2,换一个我也可以使用e,这并无关系。
3.代码
我们来看一下代码具体是怎么实现困惑度的。
probs = np.take(probs, target, axis=1).diagonal()
total += -np.sum(np.log(probs))
count += probs.size
perplexity = np.exp(total / count)
其实核心代码就这四行。
- 第一行,先求出 P ( w 1 , w 2 , w 3 , . . . , w N ) P(w_1,w_2,w_3,...,w_N) P(w1,w2,w3,...,wN),也就是求交叉熵。
- 第二行,对应的代码是 − log 2 P ( w 1 , w 2 , w 3 , . . . , w N ) -\log_2P(w_1,w_2,w_3,...,w_N) −log2P(w1,w2,w3,...,wN)
- 第三行,对应的代码是求
N
- 第四行,对应的代码就是 2 − 1 N log 2 P ( w 1 , w 2 , w 3 , . . . , w N ) 2^{-\frac{1}{N}\log_2P(w_1,w_2,w_3,...,w_N)} 2−N1log2P(w1,w2,w3,...,wN)
以上就是我个人对于困惑度查询资料以及完成代码之后做出的理解。