带你解锁 LLaMA 3 源代码的核心逻辑,附简化版实现

在大型语言模型的领域中,LLaMA 3 是 Meta 发布的一款具有强大性能的开源模型。本文将对 LLaMA 3 的源代码结构进行详细解析,并基于简化的 mini 版代码帮助大家理解模型的实现细节。最后,我们将通过一个小型的数据集进行模型训练的示例,展示如何应用这一模型。

一、LLaMA 3 源代码结构分析

LLaMA 3 基于 Transformer 架构,具备多头注意力机制、前馈网络、旋转位置编码等现代语言模型的关键技术。在代码结构上,LLaMA 3 通过模块化设计,分解成多个组件来实现每个部分的功能。让我们逐步分析每个模块的作用。

1. ModelArgs:模型参数类

LLaMA 模型首先定义了一个 ModelArgs 类,用来保存模型的各种超参数。这些参数包括嵌入维度、模型层数、多头注意力头数、词汇表大小等。代码结构如下:

class ModelArgs:
    def __init__(self, dim=4096, n_layers=32, n_heads=32, vocab_size=50257, max_seq_len=2048):
        self.dim = dim  # 嵌入维度
        self.n_layers = n_layers  # Transformer 层数
        self.n_heads = n_heads  # 多头注意力的头数
        self.vocab_size = vocab_size  # 词汇表大小
        self.max_seq_len = max_seq_len  # 最大序列长度

这些参数直接影响模型的规模和性能。例如,dim 决定了每个 token 嵌入向量的维度,n_layers 决定了 Transformer 堆叠的层数,而 n_heads 则决定了每一层注意力机制的复杂度。

2. RMSNorm:均方根归一化

LLaMA 3 使用了 RMSNorm(均方根归一化)代替传统的 LayerNorm。RMSNorm 可以更好地在深度网络中保持数值稳定性,尤其是在嵌入维度较大的情况下。

class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * norm

RMSNorm 会将输入的每个向量进行归一化操作,再通过可学习的权重进行缩放。它的效果类似于 LayerNorm,但计算方式更加简洁。

3. MultiheadAttention:多头注意力机制

多头注意力机制是 Transformer 架构的核心之一。LLaMA 的实现与标准的 Transformer 基本一致,每个注意力头会计算不同子空间上的 Query、Key 和 Value。

class MultiheadAttention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.n_heads = args.n_heads  # 多头数量
        self.head_dim = args.dim // args.n_heads  # 每个注意力头的维度
        assert self.head_dim * self.n_heads == args.dim  # 确保维度匹配

        self.q_proj = nn.Linear(args.dim, args.dim, bias=False)
        self.k_proj = nn.Linear(args.dim, args.dim, bias=False)
        self.v_proj = nn.Linear(args.dim, args.dim, bias=False)
        self.out_proj = nn.Linear(args.dim, args.dim, bias=False)

    def forward(self, x):
        B, T, C = x.shape  # B: batch size, T: seq_len, C: embedding dim
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_probs, v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(attn_output)

每个输入经过 Query、Key 和 Value 的线性变换后,计算注意力得分,并利用 softmax 生成权重矩阵,最终对 Value 加权求和,输出多头注意力的结果。

4. FeedForward:前馈网络

Transformer 的每个层中除了多头注意力,还包含一个前馈网络,它用于对每个 token 进行独立的非线性变换。

class FeedForward(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.fc1 = nn.Linear(args.dim, 4 * args.dim)  # 扩展 4 倍的隐藏层维度
        self.fc2 = nn.Linear(4 * args.dim, args.dim)  # 将维度映射回嵌入维度

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))  # 使用 GELU 激活函数,增加非线性

FeedForward 网络由两层线性变换组成,激活函数使用 GELU,它在性能上优于 ReLU。

5. TransformerBlock:Transformer 层

一个完整的 Transformer 层由多头注意力、前馈网络和归一化层构成,每个部分之间都有残差连接。

class TransformerBlock(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.attn = MultiheadAttention(args)
        self.norm1 = RMSNorm(args.dim)
        self.ffn = FeedForward(args)
        self.norm2 = RMSNorm(args.dim)

    def forward(self, x):
        attn_out = x + self.attn(self.norm1(x))
        return attn_out + self.ffn(self.norm2(attn_out))

在这个 TransformerBlock 中,输入数据首先经过注意力层和归一化,再经过前馈网络。每个子层都有残差连接,以保持梯度稳定性。

6. Transformer:完整的 Transformer 模型

LLaMA 3 的核心是多个 Transformer 层的堆叠,每个层之间依次传递上下文信息。

class Transformer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.blocks = nn.ModuleList([TransformerBlock(args) for _ in range(args.n_layers)])
        self.norm = RMSNorm(args.dim)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def forward(self, tokens):
        x = self.tok_embeddings(tokens)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        return self.output(x)

这个 Transformer 模型包含了:

  • 嵌入层:将输入 token 转换为向量。
  • 多层 TransformerBlock:处理序列中的上下文信息。
  • 输出层:将最终的向量映射到词汇表中的 token 概率分布。

二、简化版 mini LLaMA 代码

为了帮助大家更好地理解 LLaMA 的代码结构,这里提供一个简化版的 mini LLaMA 模型。我们将使用 WikiText-2 数据集,并通过 Hugging Face 的 datasetstransformers 库加载数据和分词器。

import math

import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AutoTokenizer
from datasets import load_dataset
from torch import nn, optim
from tqdm import tqdm
import torch.nn.functional as F

# 使用的WikiText-2数据集
dataset = load_dataset("wikitext", "wikitext-2-v1")
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # 使用BERT分词器作为示例


# 模型参数配置
class ModelArgs:
    def __init__(self, dim=512, n_layers=6, n_heads=8, vocab_size=30522,
                 max_seq_len=128):  # vocab_size = 30522 for BERT
        self.dim = dim
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len


args = ModelArgs()


# 自定义数据集类,适用于语言模型任务
class TextDataset(Dataset):
    def __init__(self, texts, tokenizer, seq_len):
        self.texts = texts
        self.tokenizer = tokenizer
        self.seq_len = seq_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]["text"]
        tokens = self.tokenizer.encode(text, max_length=self.seq_len, truncation=True, padding="max_length")
        return torch.tensor(tokens), torch.tensor(tokens)  # 自回归任务:输入序列预测自身


# 将训练和验证集转换为DataLoader
train_data = TextDataset(dataset["train"], tokenizer, args.max_seq_len)
valid_data = TextDataset(dataset["validation"], tokenizer, args.max_seq_len)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=32)


# 模型定义与构建
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weight * norm


class MultiheadAttention(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.n_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads
        assert self.head_dim * self.n_heads == args.dim

        self.q_proj = nn.Linear(args.dim, args.dim, bias=False)
        self.k_proj = nn.Linear(args.dim, args.dim, bias=False)
        self.v_proj = nn.Linear(args.dim, args.dim, bias=False)
        self.out_proj = nn.Linear(args.dim, args.dim, bias=False)

    def forward(self, x):
        B, T, C = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_output = torch.matmul(attn_probs, v)
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(attn_output)


class FeedForward(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.fc1 = nn.Linear(args.dim, 4 * args.dim)
        self.fc2 = nn.Linear(4 * args.dim, args.dim)

    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))


class TransformerBlock(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.attn = MultiheadAttention(args)
        self.norm1 = RMSNorm(args.dim)
        self.ffn = FeedForward(args)
        self.norm2 = RMSNorm(args.dim)

    def forward(self, x):
        attn_out = x + self.attn(self.norm1(x))
        return attn_out + self.ffn(self.norm2(attn_out))


class Transformer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim)
        self.blocks = nn.ModuleList([TransformerBlock(args) for _ in range(args.n_layers)])
        self.norm = RMSNorm(args.dim)
        self.output = nn.Linear(args.dim, args.vocab_size, bias=False)

    def forward(self, tokens):
        x = self.tok_embeddings(tokens)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        return self.output(x)


# 初始化模型、损失函数和优化器
model = Transformer(args)  # 新建模型
criterion = nn.CrossEntropyLoss()  # 使用交叉熵损失
optimizer = optim.Adam(model.parameters(), lr=1e-4)  # Adam 优化器


# 训练函数
def train(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc="Training"):
        inputs, targets = batch
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model.to(device)(inputs)

        # 计算损失并进行反向传播
        loss = criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(dataloader)


# 验证函数
def validate(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validating"):
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            loss = criterion(outputs.view(-1, outputs.size(-1)), targets.view(-1))
            total_loss += loss.item()
    return total_loss / len(dataloader)


# 训练模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
epochs = 3  # 训练3个epoch
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")

    train_loss = train(model, train_loader, optimizer, criterion, device)
    print(f"Training Loss: {train_loss:.4f}")

    valid_loss = validate(model, valid_loader, criterion, device)
    print(f"Validation Loss: {valid_loss:.4f}")

代码解析

  1. 加载数据:通过 Hugging Face 的 datasets 库加载 WikiText-2 数据集,并使用 BERT 的分词器进行 token 化。

  2. 定义 Transformer 模型:我们简化了模型的结构,保持了 LLaMA 3 的关键组件(RMSNorm、多头注意力、前馈网络等)。

  3. 训练与验证:使用交叉熵损失函数和 Adam 优化器,训练模型并在验证集上评估其性能。

总结

通过简化的 mini 版 LLaMA 代码,我们成功解析了 LLaMA 3 的核心架构和实现逻辑。该模型基于 Transformer,结合了现代 NLP 技术,如多头注意力机制、RMSNorm 和前馈网络等。对于想要理解和复现 LLaMA 3 的开发者,本文提供的代码和解析将是一个良好的起点。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值