LGN 小记
原文链接
https://www.aclweb.org/anthology/D19-1096.pdf
主要思路
参照了Lattice的思想,考虑将Lexicon的信息融入到标签的推理过程中,但是与之不同的是这里是用了图的思想构建图神经网络,作者说是使用了图网络的通用框架Message passing。通过将lexicon中匹配到的单词word作为图中的边,将原本的sentence word序列作为图中的节点,构建了一张图网络。
整个图网络的计算过程按论文中所述为aggregation->update->aggregation的迭代。图中包含的另外几个关键节点分别是e->c的ci,c->e的ei,以及全局聚合的g。aggregation阶段会通过节点信息流向边,以及边信息流向节点,以及边和结点流向global来聚合信息。aggregation主要用的是multi-att,注意论文中虽然描述的是ci和ei,这里之前一直没看明白,以为是利用att将信息聚合到一个结点/边的状态上,其实在计算的时候是对每个i的位置都做了att计算,正如self-att那样通过矩阵进行联合的计算。而multi只不过是增加了几个Wq,Wk,Wv矩阵,最后通过Wo进行head_num个信息的加权。
而update阶段则是基于LSTM中gate的思想分别对每个边和节点进行计算,实现的时候当然还是所有边一起进行矩阵计算。
过程中的tips
准备工作和lattice基本一致,其中的alphabet构建统计训练集中所有的出现的字,同时会有一个trie tree作为word_dict用来存储lexicon中的单词,通过trie进行recursive search搜索匹配字典的数据集中出现的单词,存到另一个alphabet中。
生成的时候build_with_gaz就是逐句扫描,将一句句子逐字的将字的id和标签放入列表,一句句子结束后通过递归扫描(这里需要注意是怎么扫的,它是从第i个位置开始到末尾为止,逐渐减少末尾位置进行扫描,所以不会出现重复的),同时还需要注意的是word_Ids中存的是[matched_Id, matched_length]两项。
后面batchify的过程就是拆分,放入Variable中,同时计算mask,这个mask这里没什么用,因为batch_size是设为1的。
Graph构建与更新
init中是首先根据是否bi-direction,是否global等设置,预先初始化好不同的module,比如MultiHeadAtt,GloAtt,node_rnn_f, edge_rnn_b等,这里multi由于会进行迭代所以设置为一个iter个的ModuleList
具体在forward的时候会根据输入进行graph_update,由于每个batch的lexicon会不同,所以graph_update会先调用construct_graph。
construct_graph会根据word_list设置:
-
batch_word_embed:edge_embs, [1, 9, 50] ,字典中匹配到的单词对应的word_embedding信息,作为edge嵌入
-
batch_bmes_embed:[1, 9, 21, 10],对应位置bmes的embedding
-
batch_nodes_mask:[1, 9, 21],node_mask,每一个word对应的字位置置为0
-
batch_words_mask_f & batch_words_mask_b:[1, 9, 21],words_mask_f and words_mask_b used to sign the beginning location from forward direction and backward direction, which are used in the aggregation edge2node
-
batch_words_length:[1, 9, 10],对每一个word进行length嵌入的拼接,第一个是创建好的unk字符占的位
-
edges_mask:[1, 9]:对curr_edge_num以内的位置置为1,考虑的是batch中不同句子的edge数量不同,但这里batch_size为1,所以没什么用
MultiAttention
这里参考了star-transformer的代码,计算wq,wk的时候没有用linear,而是用了conv2d,但是卷积核设为1,将通道扩宽成了num_head*head_num,实际上是一样的,这么做猜测是为了效率。后面的q·k也没有用matmul,而是通过* 进行