刚学习lstm+crf时,就阅读过crf层的源码,发现时间一久,就忘了。这次准备重新阅读一下,顺便做个笔记。主要的目的是深入理解代码细节,提高自己编写模型的能力。本文假定大家对lstm+crf的基本原理基本清楚。如果还不是很清楚,建议看看这篇论文: Bidirectional LSTM-CRF Models for Sequence Tagging 公众号后台回复 crf (全部小写) 即可获得论文。
crf源码的版本
网络上crf的实现版本很多,我选择的是这个: https://github.com/tensorflow/addons/blob/v0.11.2/tensorflow_addons/text/crf.py
概览
首先来看看源码的大概结构:
![6114fe238f09e19477e5bcf719ac468f.png](https://i-blog.csdnimg.cn/blog_migrate/2f6483b15bf1afaad7d714ca6953a2c7.png)
有1个类,11个方法。我们从crf_sequence_score这个方法开始。
阅读代码当然是非常抽象,我们最好在脑子里构建一个例子,阅读的过程中,想象这个例子在被计算。这样就会好理解很多。比如这个:
![c7cc2371db0336aed99ece01b1bd8c61.png](https://i-blog.csdnimg.cn/blog_migrate/eb4bc4bd32d47195f2023b9877ee2f0b.png)
crf_sequence_score
先看签名和注释:
![73d6e27fd5d5ccda88dd75e1fc2d98a5.png](https://i-blog.csdnimg.cn/blog_migrate/800b7b86f6f31c102f12fa1ab6e187d0.png)
inputs就是我们从bilstm层得到的输出。 tag_indices 就是我们要计算的一条标签序列。根据前面的例子,可以想象成(B-ORG O B-MISC O) sequence_lengths 就是句子的真实长度,主要用来计算mask transition_params 就是crf的核心结构,状态转移矩阵
该方法返回tag_indices这个标签序列的得分。
![9862362c5569e71182d9ba525cc92114.png](https://i-blog.csdnimg.cn/blog_migrate/a86f1083d1fd91c257cc66cef91ee84a.png)
这里面内部又定义了两个方法。看注释可以了解,当句子长度是1时,就没必要计算转移分数。直接从inputs中取出对应tag的值即可。也就是这个方法,相应解释也写在代码注释中了:
def _single_seq_fn():
当句子长度大于1时,就必须计算转移分数,这才是正常的状态。方法如下:
def _multi_seq_fn():
可以看到,这里面主要由两个方法组成:
crf_unary_score
crf_binary_score 我们一个一个看。
crf_unary_score
这个方法,是从inputs中,取出tag_indices对应位置的分数,然后把每个句子中的分数相加。
def crf_unary_score(
tag_indices: TensorLike, sequence_lengths: TensorLike, inputs: TensorLike
) -> tf.Tensor:
如果你对上面的reshape有所陌生,提供如下例子供参考:
![e4624bab25e4cac7aee029eb8c84528f.png](https://i-blog.csdnimg.cn/blog_migrate/643112d89a6f14d8f23b863e999e4bd9.png)
好了,今天先到这里吧。下次继续,再长你也不会看了。