Transformer中MultiheadAttention代码实现

值得注意的几个点:

  • 词输出维度必须整除head个数以便计算
  • forward中进行转置操作为了后续点乘方便
  • mask只在decoder-enccoder中使用
class MultiheadAttention(nn.Module):
    # n_heads:多头注意力的数量
    # hid_dim:每个词输出的向量维度
    def __init__(self, hid_dim, n_heads, dropout):
        super(MultiheadAttention, self).__init__()
        self.hid_dim = hid_dim
        self.n_heads = n_heads

        # 强制 hid_dim 必须整除 h
        assert hid_dim % n_heads == 0
        # 定义 W_q 矩阵
        self.w_q = nn.Linear(hid_dim, hid_dim)
        # 定义 W_k 矩阵
        self.w_k = nn.Linear(hid_dim, hid_dim)
        # 定义 W_v 矩阵
        self.w_v = nn.Linear(hid_dim, hid_dim)
        self.fc = nn.Linear(hid_dim, hid_dim)
        self.do = nn.Dropout(dropout)
        # 缩放
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads]))

    def forward(self, query, key, value, mask=None):
        # 注意 Q,K,V的在句子长度这一个维度的数值可以一样,可以不一样。
        # K: [64,10,300], 假设batch_size 为 64,有 10 个词,每个词的 Query 向量是 300 维
        # V: [64,10,300], 假设batch_size 为 64,有 10 个词,每个词的 Query 向量是 300 维
        # Q: [64,12,300], 假设batch_size 为 64,有 12 个词,每个词的 Query 向量是 300 维
        bsz = query.shape[0]
        Q = self.w_q(query)
        K = self.w_k(key)
        V = self.w_v(value)
        # 这里把 K Q V 矩阵拆分为多组注意力
        # 最后一维就是是用 self.hid_dim // self.n_heads 来得到的,表示每组注意力的向量长度, 每个 head 的向量长度是:300/6=50
        # 64 表示 batch size,6 表示有 6组注意力,10 表示有 10 词,50 表示每组注意力的词的向量长度
        # K: [64,10,300] 拆分多组注意力 -> [64,10,6,50] 转置得到 -> [64,6,10,50]
        # V: [64,10,300] 拆分多组注意力 -> [64,10,6,50] 转置得到 -> [64,6,10,50]
        # Q: [64,12,300] 拆分多组注意力 -> [64,12,6,50] 转置得到 -> [64,6,12,50]
        # 转置是为了把注意力的数量 6 放到前面,把 10 和 50 放到后面,方便下面计算
        Q = Q.view(bsz, -1, self.n_heads, self.hid_dim //
                   self.n_heads).permute(0, 2, 1, 3)
        K = K.view(bsz, -1, self.n_heads, self.hid_dim //
                   self.n_heads).permute(0, 2, 1, 3)
        V = V.view(bsz, -1, self.n_heads, self.hid_dim //
                   self.n_heads).permute(0, 2, 1, 3)

        # 第 1 步:Q 乘以 K的转置,除以scale
        # [64,6,12,50] * [64,6,50,10] = [64,6,12,10]
        # attention:[64,6,12,10]
        attention = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale

        # 如果 mask 不为空,那么就把 mask 为 0 的位置的 attention 分数设置为 -1e10,这里用“0”来指示哪些位置的词向量不能被attention到,比如padding位置,当然也可以用“1”或者其他数字来指示,主要设计下面2行代码的改动。
        if mask is not None:
            attention = attention.masked_fill(mask == 0, -1e10)

        # 第 2 步:计算上一步结果的 softmax,再经过 dropout,得到 attention。
        # 注意,这里是对最后一维做 softmax,也就是在输入序列的维度做 softmax
        # attention: [64,6,12,10]
        attention = self.do(torch.softmax(attention, dim=-1))

        # 第三步,attention结果与V相乘,得到多头注意力的结果
        # [64,6,12,10] * [64,6,10,50] = [64,6,12,50]
        # x: [64,6,12,50]
        x = torch.matmul(attention, V)

        # 因为 query 有 12 个词,所以把 12 放到前面,把 50 和 6 放到后面,方便下面拼接多组的结果
        # x: [64,6,12,50] 转置-> [64,12,6,50]
        x = x.permute(0, 2, 1, 3).contiguous()
        # 这里的矩阵转换就是:把多组注意力的结果拼接起来
        # 最终结果就是 [64,12,300]
        # x: [64,12,6,50] -> [64,12,300]
        x = x.view(bsz, -1, self.n_heads * (self.hid_dim // self.n_heads))
        x = self.fc(x)
        return x


# batch_size 为 64,有 12 个词,每个词的 Query 向量是 300 维
query = torch.rand(64, 12, 300)
# batch_size 为 64,有 12 个词,每个词的 Key 向量是 300 维
key = torch.rand(64, 10, 300)
# batch_size 为 64,有 10 个词,每个词的 Value 向量是 300 维
value = torch.rand(64, 10, 300)
attention = MultiheadAttention(hid_dim=300, n_heads=6, dropout=0.1)
output = attention(query, key, value)
## output: torch.Size([64, 12, 300])
print(output.shape)
<think>嗯,用户让我详细介绍一下Transformer模型的代码实现设计。首先,我需要回想一下Transformer的基本结构。Transformer由编码器和解码器组成,每个都有多层。编码器层包括自注意力机制和前馈网络,解码器层则多了一个交叉注意力。 用户可能需要的是代码层面的详细设计,比如各个模块如何实现。我应该从嵌入层开始,然后是位置编码,这部分可能用正弦和余弦函数,或者学习的位置嵌入。接着是注意力机制,特别是多头注意力,需要拆分成多个头,计算缩放点积注意力,然后合并。 然后,编码器和解码器的层结构,每个层里的子层连接和残差连接,LayerNormalization的应用。前馈网络通常是两个线性层,中间有激活函数。解码器部分还需要处理掩码,防止看到未来的信息。 可能还需要提到如何堆叠这些层,构建完整的Transformer模型。还有训练时的细节,比如损失函数、优化器,以及处理变长序列的掩码方法。用户可能对具体实现中的技巧感兴趣,比如参数初始化、学习率调整,或者如何在PyTorch中实现这些模块。 需要确认用户是否有框架偏好,比如TensorFlow或PyTorch,但问题里没提,所以保持通用。另外,用户可能希望了解各组件如何配合,比如数据流经过嵌入、位置编码、多头注意力的处理过程,以及训练和推理时的不同处理,比如解码器的自回归生成。 还要注意代码的结构,比如将每个组件模块化,方便复用。例如,MultiHeadAttention类、PositionalEncoding类、EncoderLayer、DecoderLayer等。最后,可能给出一个简化的代码框架,帮助用户理解整体结构。 需要避免过于深入数学细节,而侧重代码结构。但也要确保关键步骤如缩放点积注意力、掩码的应用等被解释清楚。可能还要提到实际应用中的常见问题,比如内存和计算效率,如何优化,比如使用键值缓存等。 总结下来,回答的结构应该是先概述Transformer的组成,然后分模块讲解每个部分的实现细节,最后可能给出一个代码示例,并讨论相关注意事项。</think>Transformer 模型的代码实现涉及多个核心模块的设计与组合。以下从架构设计、关键模块实现代码结构三个方面进行详细说明,以 PyTorch 框架为例: --- ### 一、整体架构设计 ```python class Transformer(nn.Module): def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1): super().__init__() self.encoder = Encoder(src_vocab_size, d_model, nhead, num_encoder_layers, dim_feedforward, dropout) self.decoder = Decoder(tgt_vocab_size, d_model, nhead, num_decoder_layers, dim_feedforward, dropout) self.final_proj = nn.Linear(d_model, tgt_vocab_size) ``` --- ### 二、核心模块实现细节 #### 1. 嵌入层 + 位置编码 ```python class PositionalEncoding(nn.Module): def __init__(self, d_model, max_len=5000): super().__init__() pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len).unsqueeze(1) div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) self.register_buffer('pe', pe) # 不参与训练的位置编码 def forward(self, x): x = x + self.pe[:x.size(1), :] # 形状匹配 [seq_len, d_model] return x ``` #### 2. 多头注意力机制 ```python class MultiHeadAttention(nn.Module): def __init__(self, d_model, nhead, dropout=0.1): super().__init__() assert d_model % nhead == 0 self.d_k = d_model // nhead self.nhead = nhead self.q_linear = nn.Linear(d_model, d_model) self.k_linear = nn.Linear(d_model, d_model) self.v_linear = nn.Linear(d_model, d_model) self.dropout = nn.Dropout(dropout) self.out_linear = nn.Linear(d_model, d_model) def forward(self, q, k, v, mask=None): # 拆分多头 [batch_size, seq_len, d_model] -> [batch_size, seq_len, nhead, d_k] q = self.q_linear(q).view(q.size(0), -1, self.nhead, self.d_k).transpose(1,2) k = self.k_linear(k).view(k.size(0), -1, self.nhead, self.d_k).transpose(1,2) v = self.v_linear(v).view(v.size(0), -1, self.nhead, self.d_k).transpose(1,2) # 计算缩放点积注意力 scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) if mask is not None: scores = scores.masked_fill(mask == 0, -1e9) attn = F.softmax(scores, dim=-1) attn = self.dropout(attn) context = torch.matmul(attn, v) # 合并多头 context = context.transpose(1,2).contiguous().view(q.size(0), -1, self.nhead*self.d_k) return self.out_linear(context) ``` #### 3. 前馈网络(FFN) ```python class PositionwiseFeedForward(nn.Module): def __init__(self, d_model, dim_feedforward=2048, dropout=0.1): super().__init__() self.linear1 = nn.Linear(d_model, dim_feedforward) self.linear2 = nn.Linear(dim_feedforward, d_model) self.dropout = nn.Dropout(dropout) self.activation = nn.ReLU() def forward(self, x): x = self.linear1(x) x = self.activation(x) x = self.dropout(x) x = self.linear2(x) return x ``` #### 4. 编码器层 ```python class EncoderLayer(nn.Module): def __init__(self, d_model, nhead, dim_feedforward, dropout): super().__init__() self.self_attn = MultiHeadAttention(d_model, nhead, dropout) self.ffn = PositionwiseFeedForward(d_model, dim_feedforward, dropout) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) def forward(self, src, src_mask=None): # 自注意力子层 src2 = self.self_attn(src, src, src, src_mask) src = src + self.dropout1(src2) src = self.norm1(src) # 前馈子层 src2 = self.ffn(src) src = src + self.dropout2(src2) src = self.norm2(src) return src ``` #### 5. 解码器层(含掩码机制) ```python class DecoderLayer(nn.Module): def __init__(self, d_model, nhead, dim_feedforward, dropout): super().__init__() self.self_attn = MultiHeadAttention(d_model, nhead, dropout) self.cross_attn = MultiHeadAttention(d_model, nhead, dropout) self.ffn = PositionwiseFeedForward(d_model, dim_feedforward, dropout) self.norm1 = nn.LayerNorm(d_model) self.norm2 = nn.LayerNorm(d_model) self.norm3 = nn.LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.dropout3 = nn.Dropout(dropout) def forward(self, tgt, memory, tgt_mask=None, memory_mask=None): # 自注意力(带未来掩码) tgt2 = self.self_attn(tgt, tgt, tgt, tgt_mask) tgt = tgt + self.dropout1(tgt2) tgt = self.norm1(tgt) # 交叉注意力(编码器输出作为k/v) tgt2 = self.cross_attn(tgt, memory, memory, memory_mask) tgt = tgt + self.dropout2(tgt2) tgt = self.norm2(tgt) # 前馈子层 tgt2 = self.ffn(tgt) tgt = tgt + self.dropout3(tgt2) tgt = self.norm3(tgt) return tgt ``` --- ### 三、关键实现技巧 1. **掩码生成**: - **Padding Mask**:处理变长序列时屏蔽无效位置 ```python def create_padding_mask(seq, pad_idx): return (seq != pad_idx).unsqueeze(1).unsqueeze(2) # [B, 1, 1, L] ``` - **Look-ahead Mask**:防止解码器看到未来信息 ```python def create_look_ahead_mask(size): return torch.triu(torch.ones(size, size), diagonal=1) == 0 ``` 2. **梯度控制**: ```python # 使用 Xavier 初始化 for p in model.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) ``` 3. **训练优化**: - 使用 Label Smoothing - 学习率 Warmup 策略 - 梯度裁剪(`torch.nn.utils.clip_grad_norm_`) --- ### 四、典型数据流 1. 输入序列 → 词嵌入 → 位置编码 2. 编码器处理 → 多头自注意力 → 前馈网络(重复N层) 3. 解码器处理 → 掩码自注意力 → 交叉注意力 → 前馈网络(重复N层) 4. 线性投影 → Softmax → 输出概率分布 --- ### 五、性能优化方向 1. **内存优化**:使用混合精度训练 (`torch.cuda.amp`) 2. **计算加速**:Flash Attention 实现 3. **批处理优化**:动态 Padding(使用 `torch.nn.utils.rnn.pad_sequence`) 4. **解码优化**:Beam Search 的并行实现 完整实现通常包含 500-1000 行代码,建议参考 [The Annotated Transformer](http://nlp.seas.harvard.edu/2018/04/03/attention.html) 等经典实现。实际工业级实现还需要考虑分布式训练、量化部署等更多工程细节。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值