_expand_mask代码阅读

作用解释

def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

这段代码定义了一个函数 _expand_mask,其目的是将一个形状为[bsz, seq_len]的掩码张量扩展(或者说变形)为形状为[bsz, 1, tgt_seq_len, src_seq_len]的张量。这在Transformer模型中的自注意力机制中是常见的操作,因为我们需要对每个头(head)、每个目标位置(tgt position)和每个源位置(src position)都有一个掩码值

参数介绍

  • mask: 一个torch.Tensor对象,表示原始的掩码。它的形状应该是[bsz, seq_len]
  • dtype: 一个torch.dtype对象,表示生成的掩码的数据类型。
  • tgt_len: 可选的整数,表示目标序列的长度。如果没有提供,那么就默认等于源序列的长度。

Optional是一个类型提示(Type hint),表示变量的值可以是指定的类型,或者可以是NoneOptionaltyping模块提供的一个特殊类型。
在这个例子中,Optional[int]表示tgt_len可以是一个整数(int),也可以是None
= None则是在定义函数时为参数tgt_len设置的默认值。如果在调用函数时没有提供tgt_len参数的值,那么它的值就会被设置为None
所以,tgt_len: Optional[int] = None的意思是:tgt_len参数的值可以是一个整数,如果在调用函数时没有提供这个参数的值,那么它的值将会被设置为None。)

逐行解释

bsz, src_len = mask.size()

获取mask的形状,分别赋值给bszsrc_len

tgt_len = tgt_len if tgt_len is not None else src_len

如果提供了tgt_len,那么就使用这个值;否则,将src_len赋值给tgt_len

expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

mask的形状从[bsz, seq_len]扩展为[bsz, 1, tgt_len, src_len]。这是通过添加两个新的维度(使用None索引)并使用.expand()方法来完成的。然后将扩展后的mask转换为dtype指定的数据类型

inverted_mask = 1.0 - expanded_mask

计算inverted_mask,它是expanded_mask的每个元素都被减去1后得到的。这样,原来的0变为了1,原来的1变为了0

inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

使用.masked_fill()方法将inverted_mask中值为True的位置(也就是原来mask中值为0的位置)填充为torch.finfo(dtype).min,这是dtype能表示的最小浮点数。这样做的目的是为了在自注意力计算中,将这些位置的权重几乎变为0,从而实现掩码的效果。

因为注意力mask中的1通常表示需要注意的位置,而在计算注意力时,为了避免不需要的位置的影响,会将这些位置的注意力值设置为非常小的负数。这样,在softmax操作时,这些位置的注意力值会接近于零,相当于没有被注意到。因此,这个函数返回的是一个经过扩展和处理的注意力mask,用于在计算注意力时限制模型的注意力范围。

  • 9
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值