用于Transformer的6种注意力的数学原理和代码实现

Transformer 的出色表现让注意力机制出现在深度学习的各处。本文整理了深度学习中最常用的6种注意力机制的数学原理和代码实现。

1、Full Attention

2017的《Attention is All You Need》中的编码器-解码器结构实现中提出。它结构并不复杂,所以不难理解。

上图 1.左侧显示了 Scaled Dot-Product Attention 的机制。当我们有多个注意力时,我们称之为多头注意力(右),这也是最常见的注意力的形式公式如下:

公式1

这里Q(Query)、K(Key)和V(values)被认为是它的输入,dₖ(输入维度)被用来降低复杂度和计算成本。这个公式可以说是深度学习中注意力机制发展的开端。下面我们看一下它的代码:

class FullAttention(nn.Module):
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
    def forward(self, queries, keys, values, attn_mask):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1. / sqrt(E)
        scores = torch.einsum("blhe,bshe->bhls", queries, keys)
        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)
        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)
        if self.output_attention:
            return (V.contiguous(), A)
        else:
            return (V.contiguous(), None)

2、ProbSparse Attention

借助“Transformer Dissection: A Unified Understanding of Transformer’s Attention via the lens of Kernel”中的信息我们可以将公式修改为下面的公式2。第i个query的attention就被定义为一个概率形式的核平滑方法(kernel smoother):

公式2

从公式 2,我们可以定义第 i 个查询的稀疏度测量如下:

最后,注意力块的最终公式是下面的公式4。

代码如下:

class ProbAttention(nn.Module):
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)
    def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
        # Q [B, H, L, D]
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape
        # calculate the sampled Q_K
        K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
        index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q
        K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
        Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)
        # find the Top_k query with sparisty measurement
        M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
        M_top = M.topk(n_top, sorted=False)[1]
        # use the reduced Q to calculate Q_K
        Q_reduce = Q[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :] # factor*ln(L_q)
        Q_K = torch.matmul(Q_reduce, K.transpose(-2, -1)) # factor*ln(L_q)*L_k
        return Q_K, M_top
    def _get_initial_context(self, V, L_Q):
        B, H, L_V, D = V.shape
        if not self.mask_flag:
            # V_sum = V.sum(dim=-2)
            V_sum = V.mean(dim=-2)
            contex = V_sum.unsqueeze(-2).expand(B, H, L_Q, V_sum.shape[-1]).clone()
        else: # use mask
            assert(L_Q == L_V) # requires that L_Q == L_V, i.e. for self-attention only
            contex = V.cumsum(dim=-2)
        return contex
    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        B, H, L_V, D = V.shape
        if self.mask_flag:
            attn_mask = ProbMask(B, H, L_Q, index, scores, device=V.device)
            scores.masked_fill_(attn_mask.mask, -np.inf)

        attn = torch.softmax(scores, dim=-1) # nn.Softmax(dim=-1)(scores)

        context_in[torch.arange(B)[:, None, None],
                   torch.arange(H)[None, :, None],
                   index, :] = torch.matmul(attn, V).type_as(context_in)
        if self.output_attention:
            attns = (torch.ones([B, H, L_V, L_V])/L_V).type_as(attn).to(attn.device)
            attns[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], index, :] = attn
            return (context_in, attns)
        else:
            return (context_in, None)
    def forward(self, queries, keys, values, attn_mask):
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        queries = queries.transpose(2,1)
        keys = keys.transpose(2,1)
        values = values.transpose(2,1)

        U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item() # c*ln(L_k)
        u = self.factor * np.ceil(np.log(L_Q)).astype('int').item() # c*ln(L_q) 

        U_part = U_part if U_part<L_K else L_K
        u = u if u<L_Q else L_Q
        
        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u) 

        # add scale factor
        scale = self.scale or 1./sqrt(D)
        if scale is not None:
            scores_top = scores_top * scale
        # get the context
        context = self._get_initial_context(values, L_Q)
        # update the context with selected top_k queries
        context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)
        
        return context.transpose(2,1).contiguous(), attn

这也是 Informer这个用于长序列时间序列预测的新型Transformer中使用的注意力。

3、LogSparse Attention

我们之前讨论的注意力有两个缺点:1. 与位置无关 2. 内存的瓶颈。为了应对这两个问题,研究人员使用了卷积算子和 LogSparse Transformers。

Transformer 中相邻层之间不同注意力机制的图示

卷积自注意力显示在(右)中,它使用步长为 1,内核大小为 k 的卷积层将输入(具有适当的填充)转换为Q/K。这种位置感知可以根据(左)中的形状正确匹配最相关的特征

他们不是使用步长为 1,卷积核大小 1,而是使用步长为 1 ,核大小为 k 的随意卷积(以确保模型无法访问未来的点)将输入转换为Q和K

代码实现

class Attention(nn.Module):
    def __init__(self, n_head, n_embd, win_len, scale, q_len, sub_len, sparse=None, attn_pdrop=0.1, resid_pdrop=0.1):
        super(Attention, self).__init__()

        if(sparse):
            print('Activate log sparse!')
            mask = self.log_mask(win_len, sub_len)
        else:
            mask = torch.tril(torch.ones(win_len, win_len)).view(1, 1, win_len, win_len)

        self.register_buffer('mask_tri', mask)
        self.n_head = n_head
        self.split_size = n_embd * self.n_head
        self.scale = scale
        self.q_len = q_len
        self.query_key = nn.Conv1d(n_embd, n_embd * n_head * 2, self.q_len)
        self.value = Conv1D(n_embd * n_head, 1, n_embd)
        self.c_proj = Conv1D(n_embd, 1, n_embd * self.n_head)
        self.attn_dropout = nn.Dropout(attn_pdrop)
        self.resid_dropout = nn.Dropout(resid_pdrop)

    def log_mask(self, win_len, sub_len):
        mask = torch.zeros((win_len, win_len), dtype=torch.float)
        for i in range(win_len):
            mask[i] = self.row_mask(i, sub_len, win_len)
        return mask.view(1, 1, mask.size(0), mask.size(1))

    def row_mask(self, index, sub_len, win_len):
        """
        Remark:
        1 . Currently, dense matrices with sparse multiplication are not supported by Pytorch. Efficient implementation
            should deal with CUDA kernel, which we haven't implemented yet.
        2 . Our default setting here use Local attention and Restart attention.
        3 . For index-th row, if its past is smaller than the number of cells the last
            cell can attend, we can allow current cell to attend all past cells to fully
            utilize parallel computing in dense matrices with sparse multiplication."""
        log_l = math.ceil(np.log2(sub_len))
        mask = torch.zeros((win_len), dtype=torch.float)
        if((win_len // sub_len) * 2 * (log_l) > index):
            mask[:(index + 1)] = 1
        else:
            while(index >= 0):
                if((index - log_l + 1) < 0):
                    mask[:index] = 1
                    break
                mask[index - log_l + 1:(index + 1)] = 1  # Local attention
                for i in range(0, log_l):
                    new_index = index - log_l + 1 - 2**i
                    if((index - new_index) <= sub_len and new_index >= 0):
                        mask[new_index] = 1
                index -= sub_len
        return mask

    def attn(self, query: torch.Tensor, key, value: torch.Tensor, activation="Softmax"):
        activation = activation_dict[activation](dim=-1)
        pre_att = torch.matmul(query, key)
        if self.scale:
            pre_att = pre_att / math.sqrt(value.size(-1))
        mask = self.mask_tri[:, :, :pre_att.size(-2), :pre_att.size(-1)]
        pre_att = pre_att * mask + -1e9 * (1 - mask)
        pre_att = activation(pre_att)
        pre_att = self.attn_dropout(pre_att)
        attn = torch.matmul(pre_att, value)

        return attn

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)

    def forward(self, x):

        value = self.value(x)
        qk_x = nn.functional.pad(x.permute(0, 2, 1), pad=(self.q_len - 1, 0))
        query_key = self.query_key(qk_x).permute(0, 2, 1)
        query, key = query_key.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
        attn = self.attn(query, key, value)
        attn = self.merge_heads(attn)
        attn = self.c_proj(attn)
        attn = self.resid_dropout(attn)
        return attn

class Conv1D(nn.Module):
    def __init__(self, out_dim, rf, in_dim):
        super(Conv1D, self).__init__()
        self.rf = rf
        self.out_dim = out_dim
        if rf == 1:
            w = torch.empty(in_dim, out_dim)
            nn.init.normal_(w, std=0.02)
            self.w = Parameter(w)
            self.b = Parameter(torch.zeros(out_dim))
        else:
            raise NotImplementedError

    def forward(self, x):
        if self.rf == 1:
            size_out = x.size()[:-1] + (self.out_dim,)
            x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)
            x = x.view(*size_out)
        else:
            raise NotImplementedError
        return x

来自:https://github.com/AIStream-Peelout/flow-forecast/

4、LSH Attention

Reformer的论文选择了局部敏感哈希的angular变体。它们首先约束每个输入向量的L2范数(即将向量投影到一个单位球面上),然后应用一系列的旋转,最后找到每个旋转向量所属的切片。这样一来就需要找到最近邻的值,这就需要局部敏感哈希(LSH)了,它能够快速在高维空间中找到最近邻。一个局部敏感哈希算法可以将每个向量 x 转换为 hash h(x),和这个 x 靠近的哈希更有可能有着相同的哈希值,而距离远的则不会。作者希望最近的向量最可能得到相同的哈希值,或者 hash-bucket 大小相似的更有可能相同。

局部敏感哈希算法使用球投影点的随机旋转,通过argmax在有符号的轴投影上建立bucket。在这个高度简化的2D描述中,对于三个不同的角哈希,两个点x和y不太可能共享相同的哈希桶(上图),除非它们的球面投影彼此接近(下图)。

通过固定一个大小为 [dₖ, b/2] 的随机矩阵 R 来获得 b 个哈希值。h(x) = argmax([xR;-xR]) 其中 [u;v] 表示两个向量的串联。这样就可以使用LSH,将查询位置I带入重写公式1:

下图我们可以示意性地解释 LSH 注意力:

  • 原始的注意力矩阵通常是稀疏的,但不利于计算
  • LSH Attention基于哈希桶进行键的排序进行查询
  • 在排序后的注意矩阵中,来自同一桶的对将聚集在对角线附近
  • 采用批处理方法,m个连续查询的块相互处理,一个块返回。

代码很长为了节约时间这里就不贴了:

https://github.com/lucidrains/reformer-pytorch/blob/master/reformer_pytorch/reformer_pytorch.py

5、Sparse Attention(Generating Long Sequences with Sparse Transformers)

OpenAI的Sparse Attention,通过“只保留小区域内的数值、强制让大部分注意力为零”的方式,来减少Attention的计算量。通过top-k选择,将注意退化为稀疏注意。这样,保留最有助于引起注意的部分,并删除其他无关的信息。这种选择性方法在保存重要信息和消除噪声方面是有效的。注意力可以更多地集中在最有贡献的价值因素上。

代码:https://github.com/openai/sparse_attention

6、Single-Headed Attention(Single Headed Attention RNN: Stop Thinking With Your Head)

SHA-RNN模型的注意力是简化到只保留了一个头并且唯一的矩阵乘法出现在query (下图Q) 那里,A是缩放点乘注意力 (Scaled Dot-Product Attention) ,是向量之间的运算。所以这种计算量比较小,能够快速的进行训练,就像它介绍的那样:

Obtain strong results on a byte level language modeling dataset (enwik8) in under 24 hours on a single GPU (12GB Titan V)

代码:https://github.com/Smerity/sha-rnn

引用:

  1. Kitaev, N., Ł. Kaiser, and A. Levskaya, Reformer: The efficient transformer. arXiv preprint arXiv:2001.04451, 2020.
  2. Li, S., et al., Enhancing the locality and breaking the memory bottleneck of transformer on time series forecasting. Advances in Neural Information Processing Systems, 2019. 32.
  3. Zhou, H., et al. Informer: Beyond efficient transformer for long sequence time-series forecasting. in Proceedings of AAAI. 2021.
  4. Vaswani, A., et al., Attention is all you need. Advances in neural information processing systems, 2017. 30.
  5. Rewon Child, Scott Gray, Alec Radford, Ilya Sutskever Generating Long Sequences with Sparse Transformers
  6. Stephen Merity ,Single Headed Attention RNN: Stop Thinking With Your Head

注意机制的发展到现在远远不止这些,在本篇文章中只整理了一些常见的注意力机制,希望对你有所帮助。

另外就是来自Erasmus University的Gianni Brauwers 和Flavius Frasincar在TKDE上发表的《A General Survey on Attention Mechanisms in Deep Learning》综述论文,提供了一个关于深度学习注意力机制的重要概述。各种注意力机制通过一个由注意力模型,统一符号和一个全面的分类注意力机制组成的框架来进行解释,还有注意力模型评价的各种方法。

有兴趣和有资源的话可以进行阅读,女神Alexandra Elbakyan的网站还未提供该论文。

https://www.overfit.cn/post/739299d8be4e4ddc8f5804b37c6c82ad

作者:Reza Yazdanfar

<think>好的,我现在需要回答用户关于Transformer架构及其多头注意力机制的问题。首先,我得回忆一下Transformer的基本结构,它主要由编码器解码器组成,每个部分都有多个层。每个编码器层包含多头注意力机制前馈神经网络,而解码器层除了这两个还有编码器-解码器注意力层。 接下来是重点,用户特别提到了多头注意力机制的工作原理。我记得自注意力机制是基础,通过计算查询、键、值向量的相关性来分配权重。多头则是将这个过程分成多个“头”,每个头在不同的子空间进行注意力计算,最后将结果合并。这样能捕捉不同方面的信息,比如远近依赖或不同语法结构。 不过,具体细节可能需要更仔细的梳理。例如,每个头的维度是原始维度除以头数,这样可以保持计算量不变。然后,每个头独立进行缩放点积注意力,生成各自的输出,再将这些输出拼接起来,通过线性变换得到最终结果。 需要确保提到的数学公式正确,比如缩放点积注意力的公式:$\text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d_k}})V$。每个头的处理步骤也要清晰,包括线性变换、拆分、注意力计算、拼接再次线性变换。 可能用户还会想知道多头注意力的优势,比如并行处理多样化特征捕获。这部分需要结合应用场景,比如机器翻译中处理不同语言结构的能力。此外,代码示例能帮助理解,比如展示如何拆分张量,计算多个头,再合并结果。 在引用方面,用户提供的参考是之前关于多头注意力的博文,我需要自然地将其融入回答中,比如在解释多头机制时提到该引用。同时,生成的相关问题要围绕Transformer的组成部分、自注意力多头注意力的区别,以及实际应用中的优化策略。 最后,要检查是否符合格式要求:行内公式用$...$,独立公式用$$...$$,结构清晰,使用中文,没有Markdown,引用标识正确放置。确保所有技术细节准确无误,逻辑连贯,帮助用户逐步理解Transformer多头注意力机制的工作原理。</think>### Transformer架构与多头注意力机制的工作原理 #### 1. Transformer架构概述 Transformer是一种基于自注意力机制深度学习模型,广泛用于自然语言处理任务(如机器翻译、文本生成)。其核心结构包含**编码器****解码器**,每部分由多层堆叠组成: - **编码器层**:多头自注意力机制 + 前馈神经网络 - **解码器层**:掩码多头自注意力机制 + 编码器-解码器注意力层 + 前馈神经网络 所有层均通过残差连接层归一化优化训练稳定性[^1]。 #### 2. 多头注意力机制的核心原理 **(1)自注意力机制基础** - 输入序列的每个词生成**查询向量$Q$、键向量$K$、值向量$V$** - 计算注意力权重: $$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$ 其中$d_k$是向量维度,$\sqrt{d_k}$用于缩放梯度[^1]。 **(2)多头拆分与并行计算** 1. **线性投影**:将$Q/K/V$通过多组权重矩阵映射到不同子空间 $$Q_i = Q W_i^Q, \quad K_i = K W_i^K, \quad V_i = V W_i^V$$ 2. **独立注意力计算**:每个“头”单独执行缩放点积注意力 3. **拼接与融合**:将多个头的输出拼接后通过线性层整合 $$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_h) W^O$$ **(3)优势** - **并行性**:多个头可并行计算,提升效率 - **多样化特征捕获**:不同头关注不同距离或类型的依赖关系(如语法结构、语义关联) #### 3. 代码示例(简化版) ```python import torch import torch.nn as nn class MultiHeadAttention(nn.Module): def __init__(self, d_model, num_heads): super().__init__() self.d_model = d_model self.num_heads = num_heads self.head_dim = d_model // num_heads # 定义权重矩阵 self.W_Q = nn.Linear(d_model, d_model) self.W_K = nn.Linear(d_model, d_model) self.W_V = nn.Linear(d_model, d_model) self.W_O = nn.Linear(d_model, d_model) def forward(self, Q, K, V): batch_size = Q.size(0) # 线性投影并拆分为多头 Q = self.W_Q(Q).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) K = self.W_K(K).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) V = self.W_V(V).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) # 计算缩放点积注意力 scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim)) attn_weights = torch.softmax(scores, dim=-1) output = torch.matmul(attn_weights, V) # 拼接多头结果并输出 output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model) return self.W_O(output) ``` #### 4. 应用场景 - **机器翻译**:捕捉长距离跨语言依赖 - **文本摘要**:识别关键句间关联 - **图像生成**:像素级全局关系建模(如Vision Transformer
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值