【AIGC】因果注意力(Causal Attention)原理及其代码实现

概述

因果注意力(Causal Attention)是一种自注意力机制,广泛应用于自回归模型中,尤其是在自然语言处理和时间序列预测等任务中。它的核心思想是在生成每个时间步的输出时,只关注当前时间步及之前的时间步,确保生成过程的因果性,从而避免模型在预测时依赖未来的信息。

工作原理

因果注意力的工作原理是通过掩码矩阵限制模型在计算每个时间步的注意力时,只关注当前时间步及之前的内容。具体地,掩码矩阵是一个下三角矩阵,其上三角部分为0,其余部分为1。这样,在计算注意力分布时,掩码矩阵将未来时间步的注意力得分设置为非常大的负值(-inf),使得这些位置在 softmax 操作后接近于零,从而不会对最终的输出产生影响。

掩码矩阵示例

掩码矩阵的结构如下:

[
 [1, 0, 0, 0],
 [1, 1, 0, 0],
 [1, 1, 1, 0],
 [1, 1, 1, 1]
]

该掩码矩阵确保每个时间步仅关注当前时间步及之前的时间步,维持因果性。

NumPy实现

以下是基于NumPy的因果注意力机制实现代码:

import numpy as np

def softmax(x):
    """Compute the softmax of vector x in a numerically stable way."""
    shift_x = x - np.max(x, axis=-1, keepdims=True)
    exp_x = np.exp
### 因果注意力机制与稀疏注意力机制 #### 因果注意力机制 因果注意力Causal Attention),也被称为单向注意力,在某些特定的任务中至关重要。这种类型的注意力只允许当前时刻的信息关注到过去的时间步,而不能看到未来的信息。这在自然语言处理中的序列建模尤为有用,因为句子的理解通常是顺序性的——理解某个词时仅能依据之前的上下文[^4]。 对于实现而言,因果注意力可以通过修改标准自注意力建筑来达成。具体来说,在计算注意力分数矩阵之前加入掩码操作,使得任何来自未来的贡献都被屏蔽掉。这样做的好处是在保持了模型性能的同时还保留了一定程度上的解释性和可控性。 ```python import torch import torch.nn as nn class CausalSelfAttention(nn.Module): def __init__(self, d_model, n_heads): super().__init__() self.d_model = d_model self.n_heads = n_heads # 定义线性变换层用于生成QKV self.qkv_proj = nn.Linear(d_model, 3 * d_model) def forward(self, x, mask=None): B, T, C = x.size() qkv = self.qkv_proj(x).reshape(B, T, self.n_heads, 3*C//self.n_heads) q, k, v = qkv.chunk(3, dim=-1) attn_scores = torch.einsum('bthc,bshc->bhts', [q,k]) / math.sqrt(C/self.n_heads) if mask is not None: attn_scores += (mask[:,None,:,None].float() * -1e9) # 应用因果掩码 attn_probs = F.softmax(attn_scores, dim=-1) out = torch.einsum('bhts,bshc->bthc', [attn_probs,v]).contiguous().view(B,T,C) return out ``` #### 稀疏注意力机制 相比之下,稀疏注意力旨在减少传统全连接注意力带来的高计算成本。通过引入某种形式的选择性聚焦策略,使每个位置只需与其他少量选定的位置建立联系即可完成有效信息传递。这种方法特别适合于那些输入长度非常大但内部结构存在局部关联特性的场景下使用[^3]。 然而值得注意的是,并不是所有的运算都可以轻易地变得稀疏化;例如常见的softmax函数就难以做到这一点。因此,在设计具体的算法框架时往往需要综合考量多种因素以达到最佳平衡点。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值