Causal mask代码阅读


文章参考[https://blog.csdn.net/BIT_666/article/details/133174206],如有侵权联系立马删除
先上代码

def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)

作用解释

Causal Mask 主要用于限定模型的可视范围,防止模型看到未来的数据

当我们做生成任务的时候,我们也想对生成的这个单词做注意力计算,但是,生成的句子是一个一个单词生成的
I have a dream

I 第一次注意力计算,只有 I

I have 第二次,只有 I 和 have

I have a

I have a dream

I have a dream <eos>

掩码自注意力机制应运而生

代码解释

参数说明

函数接收四个参数:

  • input_ids_shape: 输入张量的形状,通常是一个二维的形状,包含批次大小和序列长度。
  • dtype: 张量的数据类型,如torch.float32。
  • device: 计算设备,如CPU或GPU。
  • past_key_values_length: 过去键值对的长度,这是用于transformer模型的自注意力机制,用于补齐。

逐行解释

bsz, tgt_len = input_ids_shape

首先,从input_ids_shape中获取批次大小(bsz)和目标序列长度(tgt_len

torch.full()

mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
torch.full()函数解释
torch.full(size, fill_value, *, dtype=None, device=None, requires_grad=False)

参数说明:
size:张量的形状,可以是一个整数或者一个元组,例如:(3, 3) 或 3
fill_value:张量的填充值。
dtype:张量的数据类型,默认为None,即根据输入的数据类型推断。
device:张量所在的设备,默认为None,即根据输入的设备推断。
requires_grad:是否需要计算梯度,默认为False。
该函数返回一个与指定形状相同且所有元素都被设置为指定填充值的新张量。

函数使用

torch.tensor(torch.finfo(dtype).min是指指定类型的最小值

代码中生成的效果就是
在这里插入图片描述

torch.arange()

行代码是使用torch.arange函数创建一个一维张量,其元素是从0开始到mask.size(-1) - 1的整数序列。mask.size(-1)返回mask张量最后一个维度的大小。

torch.arange函数的参数是结束点,产生的序列不包含这个结束点。例如,torch.arange(5)将生成一个包含[0, 1, 2, 3, 4]的张量。

device=device是指定生成的张量在哪个设备上(CPU或GPU)。这需要与其他操作在同一设备上,否则会出现错误。

所以,如果mask的形状是(10, 10),那么mask_cond将是一个包含[0, 1, 2, …, 9]的一维张量,且与mask在同一设备上。

torch.view()

torch.view()函数解释

在 PyTorch 库中,view()函数用于改变一个张量(Tensor)的形状(shape)。它返回一个新的张量,其元素与原始张量相同,但形状(shape)已被改变。view 函数的行为非常类似于 NumPy的 reshape 函数。它会返回一个与原始张量共享数据但具有不同形状的新的张量。如果给定的形状与原始张量的元素总数不匹配,则会引发错误。

(在PyTorch中,view()函数用于改变张量的形状。当你传入-1作为参数,PyTorch将自动计算该维度的大小,以保证新的形状与原始张量中的元素总数相匹配。)

import torch  
# 传入-1作为参数,PyTorch将自动计算该维度的大小
x = torch.randn(4, 5)  # 创建一个4x5的随机张量  
y = x.view(20)  # 改变形状为20的一维张量  
z = x.view(-1, 10)  # 改变形状为10的一维张量,第一维度由其他维度决定
函数使用
mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
  1. mask_cond + 1: 这会给mask_cond中的每个元素加1。例如,如果mask_cond是[0, 1, 2, 3, 4],那么mask_cond + 1就是[1, 2, 3, 4, 5]。

  2. (mask_cond + 1).view(mask.size(-1), 1): 这会将mask_cond + 1重新塑形为列向量。例如,如果mask_cond + 1是[1, 2, 3, 4, 5],那么view操作后的结果是一个5x1的二维张量,如下:

[
 [1],
 [2],
 [3],
 [4],
 [5]
]
  1. mask_cond < (mask_cond + 1).view(mask.size(-1), 1): 这会比较mask_cond(mask_cond + 1).view(mask.size(-1), 1),并返回一个布尔张量,其元素是比较结果。由于mask_cond是一维的,而(mask_cond + 1).view(mask.size(-1), 1)是二维的,所以会进行广播(broadcasting),即将mask_cond复制成与(mask_cond + 1).view(mask.size(-1), 1)同形状的张量,然后再进行比较。
    广播的结果是将mask_cond复制成与(mask_cond + 1).view(mask.size(-1), 1)同形状的张量,即:
[
 [0, 1, 2, 3, 4],
 [0, 1, 2, 3, 4],
 [0, 1, 2, 3, 4],
 [0, 1, 2, 3, 4],
 [0, 1, 2, 3, 4]
]

比较的结果是:

[
 [ True, False, False, False, False],
 [ True,  True, False, False, False],
 [ True,  True,  True, False, False],
 [ True,  True,  True,  True, False],
 [ True,  True,  True,  True,  True]
]
  1. mask.masked_fill_(...): 这会将mask中对应布尔张量为True的位置填充为0。由于masked_fill_是原地操作(in-place operation),所以mask会被直接修改,不会创建新的张量。
[
 [0, inf, inf, inf, inf],
 [0,   0, inf, inf, inf],
 [0,   0,   0, inf, inf],
 [0,   0,   0,   0, inf],
 [0,   0,   0,   0,   0]
]

其中,inf表示浮点数的最大值,0表示被填充的位置。

torch.fill()函数解释(插)

在 PyTorch 库中,masked_fill_() 函数是一个张量(Tensor)方法,用于将张量中的指定区域填充为特定值。此函数需要一个掩码(mask)作为输入,该掩码应与原张量具有相同的形状。掩码中的 True 值表示需要填充的区域,False 值表示需要保留的原始值。

torch.Tensor.masked_fill_(mask, value)
参数说明:
mask (Bool tensor) - 掩码张量,用于指定需要填充的区域。
value (float) - 填充的值。

import torch  
  
# 创建一个3x3的张量  
x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])  
  
# 创建一个掩码,其中大于5的元素为True,其余为False  
mask = x > 5  
  
# 使用masked_fill_函数将大于5的元素替换为-1  
x.masked_fill_(mask, -1)  
  
print(x)
tensor([[ 1,  2,  3],  
        [ 4,  5,  6],  
        [-1, -1, -1]])

根据 mask 对 target x target 的方阵进行填充 0 得到我们上面提到的倒三角:
在这里插入图片描述

mask.to()

mask = mask.to(dtype)这行代码的作用是将mask张量的数据类型转换为dtype指定的数据类型。

past_key_values_length

在 PyTorch 中,past_key_values_length 是一个参数,用于指定在使用 Transformer 模型时,过去键值缓存(past key-value cache)的长度。该参数通常与 Transformer 模型中的自注意力机制(self-attention mechanism)一起使用。在过去键值缓存中,模型保存了过去的键和值向量,以便在生成序列时重复使用它们。这些过去的键和值向量可以用于计算自注意力分数,从而提高生成序列的效率。较大的past_key_values_length可以增加模型的表现力,但也会增加计算量和内存消耗。因此,需要根据具体任务和资源限制来选择合适的值。

if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
  • if past_key_values_length > 0::这个条件判断是否有过去的键值对需要考虑。如果past_key_values_length大于0,说明有过去的键值对,那么就需要执行以下的代码。

  • torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device): 这行代码创建了一个全零的张量,形状为(tgt_len, past_key_values_length),数据类型和设备与之前的mask相同。

  • torch.cat([..., mask], dim=-1): 这行代码将新创建的全零张量和原来的mask拼接在一起。这里dim=-1表示在最后一个维度上进行拼接。也就是说,新的掩码的宽度(第二个维度)等于past_key_values_length和原始mask的宽度之和。

在这里插入图片描述
这里的拼接顺序是先全零张量,后mask。也就是说,全零张量会被添加到原有mask的左侧

这样做的目的是,对于过去的键值对,我们不需要对其进行掩码操作,所以这部分的掩码是全零的;对于当前的键值对,我们需要阻止模型查看未来的信息,所以这部分的掩码是下三角形状的。两部分拼接在一起,就构成了完整的掩码。!!!

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值