因为本人基础不是很好,所以对函数中的每个方法都会有详细的解释
def _make_causal_mask(
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
"""
Make causal mask used for bi-directional self-attention.
"""
bsz, tgt_len = input_ids_shape
mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
mask_cond = torch.arange(mask.size(-1), device=device)
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
mask = mask.to(dtype)
if past_key_values_length > 0:
mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
此函数主要用于生成因果掩码,输入中的past_key_values_length为过去键值的长度,输入张量的形状input_ids_shape为batch_size和sqlen
1.创建一个形状为sqlen*sqlen的mask矩阵,用finfo填充
2.创建一个mask_cond张量,形状为(0,sqlen-1)
3.在mask中创建因果掩码
4.如果过去键值长度>0,则在mask矩阵之前加一个(sqlen,past_key_values_length)的全0矩阵
5.返回形状为(bsz, 1, sqlen, sqlen + past_key_values_length)
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
是一个用于扩展注意力掩码的函数
此函数中tgt_len为目标序列的长度,src_len为掩码长度(源序列长度)
将mask掩码转换为[bsz, 1, tgt_len, src_len]形状,此方法中运用了反转掩码实现了因果掩码,不使用反转掩码也是可以达成这个效果的,我在gpt上得到的回答是反转后1表示无效,0表示有效,更直观的表示了哪些位置是可以允许的,哪些是不允许的
上述两个函数的功能就是生成因果掩码以及扩展因果掩码