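Pointer-Generator Network
The model generates an abstract from an article. The generation vocabulary (tgt_vocab) has a fixed, limited size, so a pointer (copy) mechanism is added to handle out-of-vocabulary (OOV) words.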
article2ids
```python
def article2ids(article_words, vocab):
    ids = []
    oovs = []
    unk_id = vocab.word2id(UNKNOWN_TOKEN)
    for w in article_words:
        i = vocab.word2id(w)
        if i == unk_id:  # If w is OOV
            if w not in oovs:  # Add to list of OOVs
                oovs.append(w)
            oov_num = oovs.index(w)  # This is 0 for the first article OOV, 1 for the second article OOV...
            ids.append(vocab.size() + oov_num)  # This is e.g. 50000 for the first article OOV, 50001 for the second...
        else:
            ids.append(i)
    return ids, oovs
```
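This effectively extends the generation vocabulary with the current article's OOVs. The first time an OOV word appears, it is appended after the fixed vocabulary and assigned the next free index (vocab.size(), vocab.size() + 1, ...); later occurrences of the same word reuse that index. The result is a temporary, per-article extended vocabulary whose extra words are returned as oovs.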
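A toy run makes the extension concrete. This is a minimal sketch: the Vocab class is a hypothetical stand-in that provides only the two methods article2ids calls (word2id and size), and UNKNOWN_TOKEN = '[UNK]' is an assumed value. The article2ids body is the function above, lightly condensed. Note how the repeated OOV "tom" reuses its temporary id.

```python
UNKNOWN_TOKEN = '[UNK]'  # assumed special-token string


class Vocab:
    """Hypothetical toy vocabulary: known words get fixed ids, anything else maps to UNK."""

    def __init__(self, words):
        self._word_to_id = {UNKNOWN_TOKEN: 0}
        for w in words:
            self._word_to_id.setdefault(w, len(self._word_to_id))

    def word2id(self, word):
        return self._word_to_id.get(word, self._word_to_id[UNKNOWN_TOKEN])

    def size(self):
        return len(self._word_to_id)


def article2ids(article_words, vocab):
    ids = []
    oovs = []
    unk_id = vocab.word2id(UNKNOWN_TOKEN)
    for w in article_words:
        i = vocab.word2id(w)
        if i == unk_id:  # w is OOV
            if w not in oovs:
                oovs.append(w)
            ids.append(vocab.size() + oovs.index(w))  # temporary extended id
        else:
            ids.append(i)
    return ids, oovs


vocab = Vocab(['the', 'cat', 'sat', 'on', 'mat'])  # size() == 6: UNK plus 5 words
article = 'the cat named tom sat on tom mat'.split()
ids, oovs = article2ids(article, vocab)
print(ids)   # [1, 2, 6, 7, 3, 4, 7, 5]
print(oovs)  # ['named', 'tom']
```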
abstract2ids
```python
def abstract2ids(abstract_words, vocab, article_oovs):
    ids = []
    unk_id = vocab.word2id(UNKNOWN_TOKEN)
    for w in abstract_words:
        i = vocab.word2id(w)
        if i == unk_id:  # If w is an OOV word
            if w in article_oovs:  # If w is an in-article OOV
                vocab_idx = vocab.size() + article_oovs.index(w)  # Map to its temporary article OOV number
                ids.append(vocab_idx)
            else:  # If w is an out-of-article OOV
                ids.append(unk_id)  # Map to the UNK token id
        else:
            ids.append(i)
    return ids
```
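abstract2ids applies the same mapping to the reference abstract, so the training target can point at an OOV that is copyable from the article. A minimal sketch, again with a hypothetical toy Vocab (only word2id and size) and an assumed UNKNOWN_TOKEN; article_oovs is hard-coded to what article2ids returns for the toy article "the cat named tom sat on tom mat".

```python
UNKNOWN_TOKEN = '[UNK]'  # assumed special-token string


class Vocab:
    """Hypothetical toy vocabulary exposing only word2id and size."""

    def __init__(self, words):
        self._word_to_id = {UNKNOWN_TOKEN: 0}
        for w in words:
            self._word_to_id.setdefault(w, len(self._word_to_id))

    def word2id(self, word):
        return self._word_to_id.get(word, self._word_to_id[UNKNOWN_TOKEN])

    def size(self):
        return len(self._word_to_id)


def abstract2ids(abstract_words, vocab, article_oovs):
    ids = []
    unk_id = vocab.word2id(UNKNOWN_TOKEN)
    for w in abstract_words:
        i = vocab.word2id(w)
        if i == unk_id:
            if w in article_oovs:  # copyable: use its temporary article id
                ids.append(vocab.size() + article_oovs.index(w))
            else:  # not in this article, cannot be copied: fall back to UNK
                ids.append(unk_id)
        else:
            ids.append(i)
    return ids


vocab = Vocab(['the', 'cat', 'sat', 'on', 'mat'])
article_oovs = ['named', 'tom']
abstract = 'tom sat with jerry'.split()
target_ids = abstract2ids(abstract, vocab, article_oovs)
print(target_ids)  # [7, 3, 0, 0]
```

"tom" occurs in the article, so the target uses its temporary id 7 and the model is trained to copy it; "with" and "jerry" are neither in the fixed vocabulary nor in the article, so they fall back to UNK.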
Decoding
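At each decoding step, the vocabulary distribution (weighted by p_gen) and the attention distribution (weighted by 1 - p_gen) are merged into a single distribution over the extended vocabulary:
```python
final_dist = vocab_dist_.scatter_add(dim=1, index=enc_batch_extend_vocab, src=attn_dist_)
```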
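Here enc_batch_extend_vocab holds the article's token ids expressed in the OOV-extended vocabulary (the ids, not the vocabulary itself). It is obtained from article2ids:
```python
self.enc_batch_extend_vocab, self.article_oovs = data.article2ids(article_words, vocab)
```
Because every in-article OOV is mapped to an id at or beyond vocab.size(), scatter_add adds its attention weight to a slot that exists only for the current article, so the model can output that word even though it is not in tgt_vocab. For those indices to be valid, vocab_dist_ must first be padded with zeros up to the size of the extended vocabulary.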
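The merge corresponds to the pointer-generator mixture P(w) = p_gen · P_vocab(w) + (1 − p_gen) · Σ_{i: w_i = w} a_i, where a is the attention over article positions. Below is a dependency-free sketch for one example at one decoding step; it loops where scatter_add operates on whole batches. The function name final_distribution, the toy vocabulary and all numbers are hypothetical, and extend_ids is what article2ids produces for the toy article "the cat named tom sat on tom mat".

```python
def final_distribution(vocab_dist, attn_dist, p_gen, extend_ids, num_article_oovs):
    """Mix the generation and copy distributions over the extended vocabulary.

    vocab_dist: probabilities over the fixed vocabulary (length V)
    attn_dist: attention weights over article positions
    extend_ids: article token ids in the extended vocabulary (from article2ids)
    num_article_oovs: number of temporary OOV slots appended after the fixed vocabulary
    """
    # p_gen * P_vocab, padded with zeros so the OOV ids (>= V) are valid indices
    vocab_dist_ = [p_gen * p for p in vocab_dist] + [0.0] * num_article_oovs
    # (1 - p_gen) * attention
    attn_dist_ = [(1.0 - p_gen) * a for a in attn_dist]
    final_dist = list(vocab_dist_)
    # Single-example equivalent of scatter_add along dim=1: each position's weight
    # is added to the id of the token at that position, so repeated tokens accumulate
    for idx, a in zip(extend_ids, attn_dist_):
        final_dist[idx] += a
    return final_dist


id2word = ['[UNK]', 'the', 'cat', 'sat', 'on', 'mat']  # toy fixed vocabulary, V = 6
article_oovs = ['named', 'tom']
extend_ids = [1, 2, 6, 7, 3, 4, 7, 5]
vocab_dist = [0.1, 0.3, 0.2, 0.2, 0.1, 0.1]
attn_dist = [0.0, 0.1, 0.0, 0.5, 0.0, 0.0, 0.4, 0.0]  # attention mostly on the two "tom"s

final = final_distribution(vocab_dist, attn_dist, 0.5, extend_ids, len(article_oovs))
best = max(range(len(final)), key=lambda k: final[k])
extended_words = id2word + article_oovs
print(extended_words[best], round(final[best], 2))  # tom 0.45
```

The word "tom" is absent from the fixed vocabulary, yet it wins because the attention on both of its occurrences is summed into its temporary slot.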