作用解释
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
inverted_mask = 1.0 - expanded_mask
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
这段代码定义了一个函数 _expand_mask
,其目的是将一个形状为[bsz, seq_len]
的掩码张量扩展(或者说变形)为形状为[bsz, 1, tgt_seq_len, src_seq_len]
的张量。这在Transformer模型中的自注意力机制中是常见的操作,因为我们需要对每个头(head)、每个目标位置(tgt position)和每个源位置(src position)都有一个掩码值
参数介绍
mask
: 一个torch.Tensor
对象,表示原始的掩码。它的形状应该是[bsz, seq_len]
。dtype
: 一个torch.dtype
对象,表示生成的掩码的数据类型。tgt_len
: 可选的整数,表示目标序列的长度。如果没有提供,那么就默认等于源序列的长度。
(Optional
是一个类型提示(Type hint),表示变量的值可以是指定的类型,或者可以是None
。Optional
是typing
模块提供的一个特殊类型。
在这个例子中,Optional[int]
表示tgt_len
可以是一个整数(int
),也可以是None
。
= None
则是在定义函数时为参数tgt_len
设置的默认值。如果在调用函数时没有提供tgt_len
参数的值,那么它的值就会被设置为None
。
所以,tgt_len: Optional[int] = None
的意思是:tgt_len
参数的值可以是一个整数,如果在调用函数时没有提供这个参数的值,那么它的值将会被设置为None
。)
逐行解释
bsz, src_len = mask.size()
获取mask
的形状,分别赋值给bsz
和src_len
tgt_len = tgt_len if tgt_len is not None else src_len
如果提供了tgt_len
,那么就使用这个值;否则,将src_len
赋值给tgt_len
。
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
将mask
的形状从[bsz, seq_len]
扩展为[bsz, 1, tgt_len, src_len]
。这是通过添加两个新的维度(使用None
索引)并使用.expand()
方法来完成的。然后将扩展后的mask
转换为dtype
指定的数据类型
inverted_mask = 1.0 - expanded_mask
计算inverted_mask
,它是expanded_mask
的每个元素都被减去1后得到的。这样,原来的0变为了1,原来的1变为了0
inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
使用.masked_fill()
方法将inverted_mask
中值为True
的位置(也就是原来mask
中值为0的位置)填充为torch.finfo(dtype).min
,这是dtype
能表示的最小浮点数。这样做的目的是为了在自注意力计算中,将这些位置的权重几乎变为0,从而实现掩码的效果。
因为注意力mask中的1通常表示需要注意的位置,而在计算注意力时,为了避免不需要的位置的影响,会将这些位置的注意力值设置为非常小的负数。这样,在softmax操作时,这些位置的注意力值会接近于零,相当于没有被注意到。因此,这个函数返回的是一个经过扩展和处理的注意力mask,用于在计算注意力时限制模型的注意力范围。