Transformer

1. Step-by-Step Implementation of the Original Transformer Architecture

1.1 Overview of the Architecture

The Transformer architecture consists of an encoder and a decoder, and its core innovation is the self-attention mechanism. The key components and steps are:

  • the input is passed through an embedding layer and positional encodings are added;

  • N stacked layers of self-attention plus a feed-forward network (FFN);

  • every layer uses residual connections and LayerNorm;

  • the decoder has encoder-decoder attention in addition to self-attention.

We start from the encoder and focus on its core components.


1.2 Building the Basic Components

1.2.1 Positional Encoding

Mathematical principle: the positional encoding adds position information to each token of the input sequence so that the model becomes order-aware. The original paper uses sine and cosine functions.

The formulas are:

PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \\ PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)

Code implementation:

import torch
import torch.nn as nn
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        # 1 / 10000^(2i/d_model), computed in log space for numerical stability
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)      # even channels
        pe[:, 1::2] = torch.cos(position * div_term)      # odd channels
        pe = pe.unsqueeze(0)  # (1, max_len, d_model)
        self.register_buffer('pe', pe)  # stored with the module but not trainable

    def forward(self, x):
        # x shape: (batch_size, seq_len, d_model)
        x = x + self.pe[:, :x.size(1)]
        return x
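A quick smoke test (a hypothetical usage sketch with arbitrary sizes, not from the original text): feed a batch through the module and check that the shape is preserved and the first position matches sin(0) = 0, cos(0) = 1.

pos_enc = PositionalEncoding(d_model=512, max_len=512)
x = torch.zeros(2, 10, 512)     # (batch_size=2, seq_len=10, d_model=512)
out = pos_enc(x)
print(out.shape)                # torch.Size([2, 10, 512])
print(out[0, 0, :4])            # tensor([0., 1., 0., 1.]) since the input is all zeros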

1.2.2 Scaled Dot-Product Attention

Mathematical formula:

\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)V

Here Q (queries), K (keys), and V (values) are obtained from the input through linear projections, and d_k is the dimension of each head.

Code implementation:

def attention(q, k, v, mask=None, dropout=None):
    d_k = q.size(-1)
    # (..., seq_q, d_k) x (..., d_k, seq_k) -> (..., seq_q, seq_k), scaled by sqrt(d_k)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # positions where mask == 0 get a large negative score and vanish after softmax
        scores = scores.masked_fill(mask == 0, -1e9)
    attn = torch.softmax(scores, dim=-1)
    if dropout is not None:
        attn = dropout(attn)
    output = torch.matmul(attn, v)
    return output, attn
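A minimal check of this function (a hypothetical example; the sizes are arbitrary): build a causal mask with torch.tril and verify that the attention weights are lower-triangular.

q = k = v = torch.randn(2, 5, 64)                  # (batch, seq_len, d_k)
causal_mask = torch.tril(torch.ones(5, 5)).bool()  # each position may only attend to the past
out, attn = attention(q, k, v, mask=causal_mask)
print(out.shape)    # torch.Size([2, 5, 64])
print(attn[0])      # upper triangle is ~0 after softmax over the masked scores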

1.2.3 Multi-Head Attention

Mathematical formula:

\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1,\dots,\text{head}_h)W^O \\ \text{where}\ \text{head}_i = \text{Attention}(QW^Q_i, KW^K_i, VW^V_i)

Each head has its own set of projection parameters.

Code implementation:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_k = nn.Linear(d_model, d_model)
        self.linear_v = nn.Linear(d_model, d_model)
        self.linear_out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        batch_size = q.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1)  # add a head dimension so the same mask broadcasts over all heads
        # project, then split d_model into (num_heads, d_k) and move heads before the sequence axis
        q = self.linear_q(q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)
        k = self.linear_k(k).view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)
        v = self.linear_v(v).view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)

        x, attn = attention(q, k, v, mask, self.dropout)
        # merge the heads back: (batch, heads, seq, d_k) -> (batch, seq, d_model)
        x = x.transpose(1,2).contiguous().view(batch_size, -1, self.num_heads * self.d_k)
        return self.linear_out(x)
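A hedged usage sketch (hypothetical sizes): self-attention over a padded batch, with a padding mask of shape (batch, 1, seq_len) so that it broadcasts over heads and query positions after the unsqueeze above.

mha = MultiHeadAttention(d_model=512, num_heads=8)
x = torch.randn(2, 10, 512)        # (batch, seq_len, d_model)
pad_mask = torch.ones(2, 1, 10)    # 1 = real token, 0 = padding
pad_mask[1, :, 7:] = 0             # second sequence is padded after position 7
out = mha(x, x, x, mask=pad_mask)
print(out.shape)                   # torch.Size([2, 10, 512])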

1.2.4 Feed-Forward Network (FFN)

Mathematical formula:

\text{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2

Code implementation:

class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.linear2(self.dropout(torch.relu(self.linear1(x))))

1.2.5 LayerNorm and Residual Connections

Note: the original paper applies LayerNorm after the residual sum (post-norm, \text{LayerNorm}(x + \text{Sublayer}(x))). The AddNorm below uses the pre-norm variant (normalize first, then apply the sublayer and add the residual), a common simplification that is easier to train.

class AddNorm(nn.Module):
    def __init__(self, size, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, sublayer):
        # pre-norm residual: x + Dropout(Sublayer(LayerNorm(x)))
        return x + self.dropout(sublayer(self.norm(x)))
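For reference, a post-norm variant matching the paper's formula would look like this (a sketch only; the rest of this article keeps using the pre-norm AddNorm above):

class AddNormPostLN(nn.Module):
    # Post-norm residual block: LayerNorm(x + Sublayer(x)), as in the original paper.
    def __init__(self, size, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(size)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, sublayer):
        return self.norm(x + self.dropout(sublayer(x)))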

1.2.6 Encoder Structure

Stack N layers; each layer is MultiHeadAttention + AddNorm followed by FFN + AddNorm.

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.addnorm1 = AddNorm(d_model, dropout)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.addnorm2 = AddNorm(d_model, dropout)

    def forward(self, x, mask=None):
        x = self.addnorm1(x, lambda x: self.attn(x, x, x, mask))
        x = self.addnorm2(x, self.ffn)
        return x

import copy

class Encoder(nn.Module):
    def __init__(self, layer, N):
        super().__init__()
        # deep-copy the layer N times so the stacked layers do not share parameters
        self.layers = nn.ModuleList([copy.deepcopy(layer) for _ in range(N)])
        self.norm = nn.LayerNorm(layer.attn.linear_q.in_features)  # final LayerNorm over d_model

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

The decoder additionally needs only masked self-attention and encoder-decoder (cross) attention; a minimal sketch is given below, and the original paper or the Hugging Face implementations can be consulted for full details.
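A minimal sketch of a decoder layer under the same conventions as EncoderLayer (pre-norm AddNorm blocks); here tgt_mask would be the causal mask over target positions and src_mask the padding mask over the encoder output, both hypothetical argument names:

class DecoderLayer(nn.Module):
    # Sketch only: masked self-attention, then encoder-decoder attention, then FFN.
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
        self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout)
        self.addnorm1 = AddNorm(d_model, dropout)
        self.addnorm2 = AddNorm(d_model, dropout)
        self.addnorm3 = AddNorm(d_model, dropout)

    def forward(self, x, memory, tgt_mask=None, src_mask=None):
        # masked self-attention over the target sequence
        x = self.addnorm1(x, lambda x: self.self_attn(x, x, x, tgt_mask))
        # cross-attention: queries from the decoder, keys/values from the encoder output
        x = self.addnorm2(x, lambda x: self.cross_attn(x, memory, memory, src_mask))
        x = self.addnorm3(x, self.ffn)
        return x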


1.2.7 The Complete Transformer

Below is how a simplified encoder-only Transformer (BERT-style, for example) is assembled:

class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_len=512, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        encoder_layer = EncoderLayer(d_model, num_heads, d_ff, dropout)
        self.encoder = Encoder(encoder_layer, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, x, mask=None):
        x = self.embedding(x)     # (batch, seq, d_model)
        x = self.pos_encoding(x)
        x = self.encoder(x, mask)
        logits = self.fc_out(x)
        return logits
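A quick smoke test of the assembled model (hypothetical hyperparameters chosen only for illustration):

model = Transformer(vocab_size=1000, d_model=128, num_heads=4, d_ff=512, num_layers=2)
tokens = torch.randint(0, 1000, (2, 16))   # (batch, seq_len) of random token ids
logits = model(tokens)
print(logits.shape)                        # torch.Size([2, 16, 1000]): per-token vocabulary logits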

2. Quick Review of Key Transformer Formulas

  • Self-Attention: \text{Attention}(Q, K, V) = \text{softmax}\left( \frac{QK^T}{\sqrt{d_k}} \right)V 

  • Multi-Head: several attention heads computed in parallel and then concatenated

  • Feed-forward network: two linear layers with an activation in between

  • Residual connection: \text{LayerNorm}(x + \text{Sublayer}(x))


3. Key Improvements in Recent Large Models and Their Mathematical Principles

Next, we introduce and hand-implement several new techniques that are common and effective in current large models.


3.1 Multi-Head Latent Attention

Background: standard multi-head attention attends directly over every position of the sequence. Latent attention, an approach that has drawn interest in recent large-model work, instead introduces a small set of learnable "latent-space tokens" that exchange information with the original sequence through cross-attention, which improves efficiency on long inputs.

Mathematical principle

Suppose the input is X \in \mathbb{R}^{n \times d} (n is the sequence length and d the feature dimension). Introduce a set of learnable latent tokens L \in \mathbb{R}^{m \times d} with m \ll n:

  • Step 1: the latents attend to X (latents as queries), producing updated latents L';

  • Step 2: X attends to L' (X as queries), producing the output X'.

Code sketch:

class LatentAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_latents, dropout=0.1):
        super().__init__()
        self.latents = nn.Parameter(torch.randn(1, num_latents, d_model))
        self.cross_attn1 = MultiHeadAttention(d_model, num_heads, dropout)
        self.cross_attn2 = MultiHeadAttention(d_model, num_heads, dropout)

    def forward(self, x, mask=None):
        # x: (batch, seq, d_model); a mask, if given, should be a (batch, 1, seq) padding mask
        # so that it broadcasts against scores of shape (batch, heads, num_latents, seq)
        bsz = x.size(0)
        latents = self.latents.expand(bsz, -1, -1)       # (batch, num_latents, d_model)
        latents = self.cross_attn1(latents, x, x, mask)  # latents attend to the input
        out = self.cross_attn2(x, latents, latents)      # input attends to the latents
        return out

Advantage: better global modeling of long sequences; the computational cost O(nm) is lower than the O(n^2) of full self-attention.
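A hypothetical shape check for the sketch above (sizes chosen only for illustration): 64 latent tokens summarize a length-1024 sequence.

latent_attn = LatentAttention(d_model=256, num_heads=8, num_latents=64)
x = torch.randn(2, 1024, 256)
out = latent_attn(x)
print(out.shape)    # torch.Size([2, 1024, 256]); both cross-attentions cost O(n*m), not O(n^2)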


3.2 Swish GLU (Gated Linear Unit with Swish Activation)

Background: the feed-forward network traditionally uses ReLU or GELU. GLU-style blocks, popular in Llama and other recent large models, add a gating mechanism that increases expressive power, and the Swish activation (x \cdot \sigma(x), with \sigma the sigmoid) further strengthens the nonlinearity.

Mathematical principle

\text{SwishGLU}(x) = (xW_1) \odot \text{Swish}(xW_2)

where W_1 and W_2 are separate weight matrices and \text{Swish}(z) = z \cdot \sigma(z), with \sigma the sigmoid.

Code implementation:

import torch.nn.functional as F

class SwishGLU(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_model, d_ff)

    def forward(self, x):
        # value branch gated by Swish(fc2(x)); F.silu(z) = z * sigmoid(z)
        return self.fc1(x) * F.silu(self.fc2(x))

(Swish itself can be written as x \cdot \text{sigmoid}(x); using a plain sigmoid gate instead would give the original GLU.)

class FeedForwardSwishGLU(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)   # value projection
        self.fc2 = nn.Linear(d_model, d_ff)   # gate projection
        self.fc3 = nn.Linear(d_ff, d_model)   # projection back to d_model
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        # Swish GLU: value * Swish(gate), then project back down
        x1 = self.fc1(x)
        x2 = self.fc2(x)
        x = x1 * F.silu(x2)
        x = self.dropout(self.fc3(x))
        return x
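FeedForwardSwishGLU has the same interface as PositionwiseFeedForward, so it can be swapped directly into EncoderLayer; a hypothetical example with arbitrary sizes:

layer = EncoderLayer(d_model=256, num_heads=8, d_ff=1024, dropout=0.1)
layer.ffn = FeedForwardSwishGLU(d_model=256, d_ff=1024, dropout=0.1)  # drop-in replacement
x = torch.randn(2, 32, 256)
print(layer(x).shape)    # torch.Size([2, 32, 256])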

3.3 Rotary Positional Embedding (RoPE)

Background: compared with the traditional positional encoding, RoPE captures relative positional information between tokens more effectively, and it is widely used in Llama, ChatGLM, and other models.

Mathematical principle

Each pair of channels is mapped onto a two-dimensional plane and rotated by a position-dependent angle; in effect, a position-dependent phase rotation is injected into the computation of the attention scores, so relative positions show up as phase differences between Q and K.

Core formula (slightly simplified; x_1 and x_2 denote the two halves of the head dimension, matching the implementation below):

x'_1 = x_1 \cos(\theta_{pos}) + x_2 \sin(\theta_{pos}) \\ x'_2 = x_2 \cos(\theta_{pos}) - x_1 \sin(\theta_{pos})

Code implementation (rotating Q and K, assuming shape [batch, seq_len, num_heads, head_dim]):

def apply_rope(x, seq_dim=1):
    # x: (batch, seq, heads, head_dim)
    head_dim = x.shape[-1]
    half_dim = head_dim // 2
    freqs = torch.arange(half_dim, device=x.device).float() / half_dim
    inv_freq = 1.0 / (10000 ** freqs)                                    # (half_dim,)
    positions = torch.arange(x.shape[seq_dim], device=x.device).float()
    sinusoid = torch.einsum('i,j->ij', positions, inv_freq)              # (seq, half_dim)
    # reshape to (1, seq, 1, half_dim) so sin/cos broadcast over batch and heads
    sin = torch.sin(sinusoid)[None, :, None, :]
    cos = torch.cos(sinusoid)[None, :, None, :]
    x1, x2 = x[..., :half_dim], x[..., half_dim:]
    # rotate each (x1, x2) pair by the position-dependent angle
    x = torch.cat([x1 * cos + x2 * sin, x2 * cos - x1 * sin], dim=-1)
    return x
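A hypothetical usage sketch: rotate Q and K (but not V) before calling the attention function from Section 1.2.2; the relative position information then enters through the Q-K dot products.

q = torch.randn(2, 16, 8, 64)   # (batch, seq, heads, head_dim)
k = torch.randn(2, 16, 8, 64)
v = torch.randn(2, 16, 8, 64)
q, k = apply_rope(q), apply_rope(k)   # V is left unrotated
out, attn = attention(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
print(out.shape)    # torch.Size([2, 8, 16, 64])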

3.4 Other Recent Innovations and Trends

a. FlashAttention

  • Mathematical principle: it still computes exactly \text{softmax}(\frac{QK^T}{\sqrt{d_k}}) V, but the implementation (tiling and recomputation) greatly improves memory and bandwidth efficiency, making it well suited to long-sequence training and inference for large models.

  • Code: PyTorch 2.0 already ships torch.nn.functional.scaled_dot_product_attention; a minimal usage sketch follows.
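A minimal usage sketch (which backend actually runs, FlashAttention, memory-efficient attention, or the math fallback, depends on hardware, dtype, and shapes):

import torch.nn.functional as F

# q, k, v: (batch, heads, seq_len, head_dim); is_causal=True applies a causal mask internally
q = torch.randn(2, 8, 128, 64)
k = torch.randn(2, 8, 128, 64)
v = torch.randn(2, 8, 128, 64)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape)    # torch.Size([2, 8, 128, 64])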

b. Attention Mask Optimizations

  • Long-context support: blocked attention, (sliding-)window attention, and sparse attention; a small mask-construction sketch follows this list.

  • In multimodal models (e.g., Qwen-VL), the attention mask can be adjusted flexibly to control how tokens from different modalities interact.
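As an illustration of the windowed case (a hypothetical sketch, not any specific model's mask): a causal mask that also restricts each position to the last w tokens can be built with a couple of tensor ops and used like the causal mask in Section 1.2.2.

def sliding_window_mask(seq_len, window):
    # 1 where position j is visible from position i, i.e. j <= i and i - j < window
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return ((j <= i) & (i - j < window)).int()

print(sliding_window_mask(6, 3))    # banded lower-triangular pattern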

c. New Normalization Variants

  • RMSNorm (Root Mean Square LayerNorm); see the sketch after this list:
    \text{RMSNorm}(x) = \frac{x}{\sqrt{\text{mean}(x^2) + \epsilon}} \cdot \gamma

  • ScaleNorm, AdaNorm, and similar variants, which are often a better fit for large models.
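A minimal RMSNorm sketch following the formula above (no mean subtraction and no bias, only RMS scaling with a learnable gain):

class RMSNorm(nn.Module):
    # Root Mean Square LayerNorm: x / sqrt(mean(x^2) + eps) * gamma
    def __init__(self, d_model, eps=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.eps = eps
    def forward(self, x):
        rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return x / rms * self.gamma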

d. Quantization and Sparsification

  • These techniques (e.g., llm.int8(), GPTQ, AWQ) save memory and compute when deploying large models for inference; the underlying math is about discretely approximating weights and activations while keeping the resulting error small.
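As a toy illustration of the basic idea only (a generic symmetric int8 round-trip, not llm.int8(), GPTQ, or AWQ themselves):

def quantize_int8(w):
    # per-tensor symmetric quantization to int8 (assumes w is not all zeros)
    scale = w.abs().max() / 127.0
    q = torch.clamp(torch.round(w / scale), -128, 127).to(torch.int8)
    return q, scale

w = torch.randn(256, 256)
q, scale = quantize_int8(w)
w_hat = q.float() * scale          # dequantize
print((w - w_hat).abs().max())     # the round-trip error stays small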


4. References and Further Reading


Summary

  1. The code above implements the backbone of the original Transformer architecture, with the relevant formula and an explanation for each step.

  2. You have seen how to implement, and swap in, the newer techniques now common in large models: latent attention, Swish GLU, RoPE, FlashAttention, and others.

  3. Possible next steps include training and optimization tricks, multimodal fusion, instruction tuning (SFT, DPO), and distributed/parallel training.
