记录一些transformer包中自带的函数功能

因为本人基础不是很好,所以对函数中的每个方法都会有详细的解释

def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

此函数主要用于生成因果掩码,输入中的past_key_values_length为过去键值的长度,输入张量的形状input_ids_shape为batch_size和sqlen

1.创建一个形状为sqlen*sqlen的mask矩阵,用finfo填充

2.创建一个mask_cond张量,形状为(0,sqlen-1)

3.在mask中创建因果掩码

4.如果过去键值长度>0,则在mask矩阵之前加一个(sqlen,past_key_values_length)的全0矩阵

5.返回形状为(bsz, 1, sqlen, sqlen + past_key_values_length)

def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

是一个用于扩展注意力掩码的函数

此函数中tgt_len为目标序列的长度,src_len为掩码长度(源序列长度)

将mask掩码转换为[bsz, 1, tgt_len, src_len]形状,此方法中运用了反转掩码实现了因果掩码,不使用反转掩码也是可以达成这个效果的,我在gpt上得到的回答是反转后1表示无效,0表示有效,更直观的表示了哪些位置是可以允许的,哪些是不允许的

上述两个函数的功能就是生成因果掩码以及扩展因果掩码

  • 10
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值