Transformer核心代码-#notebook

Transformer 核心代码multi-head self attention

transformer 编码器和解码器架构,主要的结构是multi-head self-attention,残差连接,FFN,层归一化。其中核心的代码是关于multi-head self-attention的实现,下面给出一个逐行代码解释。

具体的pytorch实现可以参考此链接

首先回顾一下计算公式:

设输入序列为 x 1 : T ∈ R D h × T \mathbf{x}_{1:T} \in \mathbb{R}^{D_h \times T} x1:TRDh×T,其嵌入表示并添加位置编码表示位置信息,有
H = [ e x 1 + p 1 , ⋯   , e x T + p T ] \mathbf{H}=[e_{\mathbf{x}_1}+p_1,\cdots,e_{\mathbf{x}_T}+p_T] H=[ex1+p1,,exT+pT]

多头自注意力(Multi-HeadSelf-Attention),在多个不同的投影空间中捕捉不同的交互信息.假设在 M M M 个投影空间中分别应用自注意力模型,有
MultiHead ⁡ ( Q , K , V ) = W o [ head ⁡ 1 ; ⋯   ; head ⁡ M ] , head ⁡ m = self-att ⁡ ( Q m , K m , V m ) , ∀ m ∈ { 1 , ⋯   , M } , Q m = W q m Q , K = W k m K , V = W v m V , \begin{aligned} \operatorname{MultiHead}(\mathbf{Q},\mathbf{K},\mathbf{V})&=\boldsymbol{W}_o[\operatorname{head}_1;\cdots;\operatorname{head}_M],\\ \operatorname{head}_m&=\operatorname{self-att}(\boldsymbol{Q}_m,\boldsymbol{K}_m,\boldsymbol{V}_m),\\ \forall m\in\{1,\cdots,M\},\quad \boldsymbol{Q}_m&=\boldsymbol{W}_q^m\boldsymbol{Q},\boldsymbol{K}=\boldsymbol{W}_k^m\boldsymbol{K},\boldsymbol{V}=\boldsymbol{W}_v^m\boldsymbol{V},\end{aligned} MultiHead(Q,K,V)headmm{1,,M},Qm=Wo[head1;;headM],=self-att(Qm,Km,Vm),=WqmQ,K=WkmK,V=WvmV,
其中 W o ∈ R D h × M ⋅ D v \boldsymbol{W}_o \in \mathbb{R}^{D_h \times M \cdot D_v} WoRDh×MDv为输出投影矩阵, W q m ∈ R D k × D h \boldsymbol{W}_q^m \in \mathbb{R}^{D_k \times D_h} WqmRDk×Dh, W k m ∈ R D k × D h \boldsymbol{W}_k^m \in \mathbb{R}^{D_k \times D_h} WkmRDk×Dh, W q m ∈ R D v × D h \boldsymbol{W}_q^m \in \mathbb{R}^{D_v \times D_h} WqmRDv×Dh, 为投影矩阵, m ∈ { 1 , . . . , M } m \in \{ 1,...,M\} m{1,...,M}.

自注意力模型 self-att ⁡ \operatorname{self-att} self-att可以看作在一个线性投影空间中建立 H \mathbf{H} H(自注意力中 Q , K , V 均为 H \mathbf{Q},\mathbf{K},\mathbf{V}均为\mathbf{H} Q,K,V均为H)中不同向量之间的交互关系.其计算公式为
self-att ⁡ ( Q , K , V ) = V softmax ⁡ ( K ⊺ Q D k ) , Q = W q H , K = W k H , V = W υ H , \begin{aligned}\operatorname{self-att}(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})&=\boldsymbol{V}\operatorname{softmax}\big(\frac{\boldsymbol{K}^\intercal\boldsymbol{Q}}{\sqrt{D_k}}\big),\\\boldsymbol{Q}&=\boldsymbol{W}_q\mathbf{H},\boldsymbol{K}=\boldsymbol{W}_k\mathbf{H},\boldsymbol{V}=\boldsymbol{W}_\upsilon\mathbf{H},\end{aligned} self-att(Q,K,V)Q=Vsoftmax(Dk KQ),=WqH,K=WkH,V=WυH,

其中 D k D_k Dk是输入矩阵 Q \boldsymbol{Q} Q K \boldsymbol{K} K中列向量的维度, W q ∈ R D k × D h \boldsymbol{W}_q \in \mathbb{R}^{D_k \times D_h} WqRDk×Dh, W k ∈ R D k × D h \boldsymbol{W}_k \in \mathbb{R}^{D_k \times D_h} WkRDk×Dh, W v ∈ R D v × D h \boldsymbol{W}_v \in \mathbb{R}^{D_v \times D_h} WvRDv×Dh为三个投影矩阵。

通过对公式的具体分析,在多头注意力机制的计算中,频繁涉及到keyqueryvalue向量的构造,因此我们需要一个为多头注意力机制准备keyqueryvalue向量的封装函数。

PrepareForMultiHeadAttention

这段代码定义了一个PyTorch模块,PrepareForMultiHeadAttention,其主要作用是为多头注意力机制准备keyqueryvalue向量。这一过程涉及线性变换和将变换后的向量分割成指定数量的头。下面是对代码的逐行解释:

类定义

  • class PrepareForMultiHeadAttention(nn.Module):定义了一个名为PrepareForMultiHeadAttention的类,该类继承自PyTorch的nn.Module类。这表明PrepareForMultiHeadAttention是一个可以集成到PyTorch模型中的自定义模块。

构造函数 __init__

  • def __init__(self, d_model: int, heads: int, d_k: int, bias: bool): 这个构造函数接受四个参数:d_model是输入向量的维度;heads是要分割的头的数量;d_k是分割后每个头中向量的维度;bias是一个布尔值,指定是否在线性变换中添加偏置项。
  • self.linear = nn.Linear(d_model, heads * d_k, bias=bias) 创建一个线性层,用于对输入向量进行线性变换。变换的输出维度是头的数量乘以每个头的维度。也就是对输入表示向量的维度进行切分,切成heads份,每份大小是d_k
  • self.heads = headsself.d_k = d_k 分别存储了头的数量和每个头中向量的维度,这些信息在前向传播时用于重塑变换后的向量。

前向传播 forward

  • def forward(self, x: torch.Tensor): 定义了模块的前向传播逻辑,其中x是输入张量,其形状可以是[seq_len, batch_size, d_model][batch_size, d_model]
  • head_shape = x.shape[:-1] 获取输入张量x除最后一个维度外的形状,这用于后续重塑变换后的向量。
  • x = self.linear(x) 对输入x应用前面定义的线性变换。
  • x = x.view(*head_shape, self.heads, self.d_k) 重塑线性变换后的向量,使其最后两个维度分别为头的数量和每个头的向量维度。这样,输出张量的形状变为[seq_len, batch_size, heads, d_k][batch_size, heads, d_k]

总结

这个模块在多头注意力机制中扮演关键角色,通过对keyqueryvalue向量进行适当的线性变换和重塑,使得它们能够被分配到不同的“头”中。这种分配使得模型能够在不同的表示子空间中并行捕获信息,从而提高了模型处理复杂信息的能力。

输入:

  • d_model:输入向量的维度;
  • heads:要分割的头的数量;
  • d_k:分割后每个头中向量的维度,即 d k = d model / heads d_k = d_{\text{model}}/\text{heads} dk=dmodel/heads
  • bias:一个布尔值,指定是否在线性变换中添加偏置项。

transformer中,Q,K,V的维度大小通常是相等的,因此 d v = d k d_v = d_k dv=dk

输出:输出张量的形状变为[seq_len, batch_size, heads, d_k][batch_size, heads, d_k]

import math
from typing import Optional, List
import torch
from torch import nn
from labml import tracker
class PrepareForMultiHeadAttention(nn.Module):
    """
    ## Prepare for multi-head attention
    This module does a linear transformation and splits the vector into given
    number of heads for multi-head attention.
    This is used to transform **key**, **query**, and **value** vectors.
    该模块进行线性变换,并将向量分割为给定数量的头以进行多头注意力。这用于变换 **key**、**query** 和 **value** 向量。
    """

    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool):
        super().__init__()
        # Linear layer for linear transform
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        # Number of heads
        self.heads = heads
        # Number of dimensions in vectors in each head
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        # Input has shape `[seq_len, batch_size, d_model]` or `[batch_size, d_model]`.
        # We apply the linear transformation to the last dimension and split that into
        # the heads.
        head_shape = x.shape[:-1]
        # Linear transform
        x = self.linear(x)
        # Split last dimension into heads
        x = x.view(*head_shape, self.heads, self.d_k)
        # Output has shape `[seq_len, batch_size, heads, d_k]` or `[batch_size, heads, d_model]`
        return x

接着根据多头注意力的公式可以逐行实现Multi-Head Attention module,不过在这个过程中需要注意的是,实际计算的输入是[seq_len,batch_size,heads,d_k],因此在计算key和query的乘积时,使用的是torch.enisum()方法,它可以为各种张量运算提供一个简洁的框架,具体解释参考官方文档

另外一个需要注意的点是在解码器中用到的masked multi-head attention,这是通过设置掩码来实现的。

Multi-Head Attention Module

这段代码定义了一个MultiHeadAttention类,它是Transformer模型中多头注意力机制的实现。以下是对该类及其方法的详细解释:

类定义和初始化

  • class MultiHeadAttention(nn.Module):定义了MultiHeadAttention类,继承自PyTorch的nn.Module
  • __init__方法中,类接收几个参数:
    • heads: 多头注意力机制中头的数量。
    • d_model: 输入向量的特征维度,也是querykeyvalue向量的维度。
    • dropout_prob: dropout操作的概率,用于防止过拟合。
    • bias: 是否在PrepareForMultiHeadAttention中使用偏置项。
  • self.d_k = d_model // heads计算每个头的特征维度。
  • self.query, self.key, self.value使用PrepareForMultiHeadAttention类分别对querykeyvalue向量进行线性变换,为多头注意力计算做准备。
  • self.softmax定义了在计算注意力时沿着键(key)的时间()维度应用的softmax函数。
  • self.output是一个线性层,用于将多头注意力的输出重新映射回原始的特征空间d_model
  • self.dropout定义了dropout操作,用于在注意力权重上进行。
  • self.scale是缩放因子,用于在计算softmax之前调整注意力分数,以避免因维度较大而导致的梯度消失或爆炸。
class MultiHeadAttention(nn.Module):
    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = True):
        """
        * `heads` is the number of heads.
        * `d_model` is the number of features in the `query`, `key` and `value` vectors.
        """

        super().__init__()

        # Number of features per head
        self.d_k = d_model // heads
        # Number of heads
        self.heads = heads

        # These transform the `query`, `key` and `value` vectors for multi-headed attention.
        # input:[d_m,heads,d_k]
        # output:[seq_len, batch_size, heads, d_k]
        self.query = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.key = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=bias)
        self.value = PrepareForMultiHeadAttention(d_model, heads, self.d_k, bias=True)

        # Softmax for attention along the time dimension of `key`
        # 也就是沿着key的列方向进行softmax
        self.softmax = nn.Softmax(dim=1)
        # Output layer
        self.output = nn.Linear(d_model, d_model)
        # Dropout
        self.dropout = nn.Dropout(dropout_prob)
        # Scaling factor before the softmax
        self.scale = 1 / math.sqrt(self.d_k)
        # We store attentions so that it can be used for logging, or other computations if needed
        self.attn = None

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
        ### Calculate scores between queries and keys
        This method can be overridden for other variations like relative attention.
        """
        # Calculate $Q K^\top$ or $S_{ijbh} = \sum_d Q_{ibhd} K_{jbhd}$
        # 计算K^T 与 Q的乘积
        return torch.einsum('ibhd,jbhd->ijbh', query, key)

    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
        """
        `mask` has shape `[seq_len_q, seq_len_k, batch_size]`, where first dimension is the query dimension.
        If the query dimension is equal to $1$ it will be broadcasted.
        """

        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

        # Same mask applied to all heads.
        mask = mask.unsqueeze(-1)

        # resulting mask has shape `[seq_len_q, seq_len_k, batch_size, heads]`
        return mask

    def forward(self, *,
                query: torch.Tensor,
                key: torch.Tensor,
                value: torch.Tensor,
                mask: Optional[torch.Tensor] = None):
        """
        `query`, `key` and `value` are the tensors that store
        collection of *query*, *key* and *value* vectors.
        They have shape `[seq_len, batch_size, d_model]`.

        `mask` has shape `[seq_len, seq_len, batch_size]` and
        `mask[i, j, b]` indicates whether for batch `b`,
        query at position `i` has access to key-value at position `j`.
        """

        # `query`, `key` and `value`  have shape `[seq_len, batch_size, d_model]`
        seq_len, batch_size, _ = query.shape
        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)
        # Prepare `query`, `key` and `value` for attention computation.
        # These will then have shape `[seq_len, batch_size, heads, d_k]`.
        query = self.query(query)
        key = self.key(key)
        value = self.value(value)
        # Compute attention scores $Q K^\top$.
        # This gives a tensor of shape `[seq_len, seq_len, batch_size, heads]`.
        scores = self.get_scores(query, key)
        # Scale scores $\frac{Q K^\top}{\sqrt{d_k}}$
        scores *= self.scale
        # Apply mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        # $softmax$ attention along the key sequence dimension
        # $\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)$
        attn = self.softmax(scores)
        # Save attentions if debugging
        tracker.debug('attn', attn)
        # Apply dropout
        attn = self.dropout(attn)
        # Multiply by values
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_k}}\Bigg)V$$
        x = torch.einsum("ijbh,jbhd->ibhd", attn, value)
        # Save attentions for any other calculations 
        self.attn = attn.detach()
        # Concatenate multiple heads
        x = x.reshape(seq_len, batch_size, -1)

        # Output layer
        return self.output(x)

方法解释

  • get_scores方法通过torch.einsum计算所有query和所有key的点积,得到注意力分数矩阵。

torch.einsum是一个非常强大的函数,用于执行爱因斯坦求和约定(Einstein summation convention),它提供了一种表达多维数组之间复杂操作的简洁方式。torch.einsum接受一个操作字符串和若干个张量作为输入,操作字符串指定了输入张量的维度如何相乘和求和。

在多头注意力的上下文中,torch.einsum('ibhd,jbhd->ijbh', query, key)这行代码执行了query矩阵和key矩阵的批量点乘操作,为了生成一个表示注意力分数的矩阵。下面是对这个操作字符串和相应操作的详细解释:

torch.einsum(‘ibhd,jbhd->ijbh’, query, key)

  • 输入张量:
    • querykey是两个张量,它们的维度都是[seq_len, batch_size, heads, d_k],其中:
      • seq_len是序列长度,
      • batch_size是批量大小,
      • heads是注意力头的数量,
      • d_k是每个注意力头的维度。
  • 操作字符串:'ibhd,jbhd->ijbh'可以分解为三个部分:
    • ibhd: 第一个张量(query)的维度标签。
    • jbhd: 第二个张量(key)的维度标签。
    • ijbh: 输出张量的维度标签。

解析操作字符串

  • ibhd,jbhd: 这表示querykey张量进行操作。每个张量的维度用不同的字母表示,相同的字母表示这些维度将进行点乘操作。
    • ij分别代表querykey的序列长度维度。
    • b代表批量大小(两个张量共享这一维度)。
    • h代表头的数量(两个张量共享这一维度)。
    • d代表每个头内的特征或维度,querykey在这一维度上进行点乘。
  • ->ijbh: 输出张量的维度。这里没有d,因为d维度上的元素被求和了(点乘后求和)。输出张量的维度是:
    • ij分别代表query的序列长度和key的序列长度,这允许每个query与所有key进行比较,形成一个注意力分数矩阵。
    • b代表批量大小。
    • h代表头的数量。

操作含义

这个操作计算了每个头内,每个query向量与每个key向量的乘积,并将结果求和(因为维度d没有出现在输出中)。这相当于计算注意力机制中的原始分数(未缩放的点乘注意力分数)。

对于每个批次中的每个头,你会得到一个[seq_len, seq_len]的分数矩阵,表示序列中每个位置的query如何与序列中每个位置的key相互作用。这个分数矩阵接下来会被缩放、掩码处理(如果有的话),然后应用softmax函数来得到最终的注意力权重。

  • prepare_mask方法用于处理掩码张量,使其适用于序列长度和头的维度。这在处理不同长度的序列时非常有用,可以阻止模型看到序列中的某些部分。也就是masked multi-head attention
  • forward方法是执行多头注意力计算的主要函数。它首先调整querykeyvalue的形状以适应多头计算,然后计算注意力分数,应用缩放和掩码,最后通过softmax获取注意力权重。使用这些权重和value计算加权和,最后通过输出层将结果映射回原始维度。

注意力计算过程

  1. 使用PrepareForMultiHeadAttentionquerykeyvalue进行线性变换并分头处理。
  2. 计算querykey的乘积,得到注意力分数矩阵。
  3. 应用缩放因子。
  4. 如果提供了掩码,应用掩码。
  5. 对分数应用softmax函数,得到注意力权重。
  6. 应用dropout到注意力权重上。
  7. 使用注意力权重对value进行加权求和。
  8. 将多个头的输出拼接并通过最后的线性层。

这个实现允许模型在不同的表示子空间中并行捕获信息,这是Transformer架构的关键特性之一,提高了处理复杂信息的能力。

  • 12
    点赞
  • 25
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值