国庆节疯玩了7天,感觉整个人已经成为一条咸鱼了,在上班前一晚把attention机制复习了一下,就当是收收心了(感到羞愧)。
首先实现一些函数,称为attention_utils.py
import numpy as np
import torch
import torch.nn.functional as F
def create_src_lengths_mask(batch_size, src_lengths):
'''
生成布尔掩码以防止注意力超出source的末尾
:param batch_size: int
:param src_lengths: [batch_size] 每个句子的实际长度
:return: [batch_size, max_src_len]
'''
max_src_len = src_lengths.max()
# [1, max_src_len]
src_indices = torch.arange(0, max_src_len).unsqueeze(0).type_as(src_lengths)
# [batch_size, max_src_len]
src_indices = src_indices.expand(batch_size, max_src_len)
src_lengths = src_lengths.unsqueeze(1).expand(batch_size, max_src_len)
# 小于实际长度的为1,大于的为0,detach截断反向梯度传播
return (src_indices < src_lengths).int().detach()
def masked_softmax(scores, src_lengths, src_length_masking=True):
'''
先生成mask,然后再进行softmax。
'''
if src_length_masking:
batch_size, max_src_len = scores.size()
# compute masks
src_mask = create_src_lengths_mask(batch_size, src_lengths)
# Fill pad positions with -inf
scores = scores.masked_fill(src_mask == 0, -np.inf)
# 转换为float16,然后再次转换回来以防止loss爆炸
return F.softmax(scores.float(), dim=-1).type_as(scores)
然后实现一个无attention的基础类,base_attention.py
import