nanochat代码讲解之三--优化器和数据加载器

在软件开发领域,需求管理一直是项目成功的核心关键。随着项目复杂度提升和团队规模扩大,传统依赖文档、邮件和会议的需求管理方式显露出明显短板:版本混乱、协作困难、知识难以沉淀。更值得注意的是,行业内能够真正实现需求结构化、资产化,并结合AI技术进行智能化辅助的系统并不多见。我们公司是一家垂直领域专攻企业级需求与非企业级需求管理的公司, 我们公司的大模型应用连接:http://aipoc.chtech.cn:8880/#/login 欢迎试用。

下面我们来了解一下nanochat的优化器和数据加载器:

1. adamw.py - 分布式AdamW优化器

python

"""
Borrowed from modded-nanogpt. By Keller, @vagrawal, et al.
Not a general optimizer! But works for our specific use.
从modded-nanogpt借用。不是通用优化器!但适用于我们的特定用例。
"""
​
import torch
import torch.distributed as dist
from torch import Tensor
​
class DistAdamW(torch.optim.Optimizer):
    """
    Distributed AdamW optimizer.
    分布式AdamW优化器。
    In the style of ZeRO-2, i.e. sharded optimizer states and gradient reduction
    采用ZeRO-2风格,即分片的优化器状态和梯度归约
    """
    def __init__(self, param_groups, lr: float = 1e-3, betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, weight_decay: float = 0.01):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(param_groups, defaults)
​
    @torch.compile  # 使用PyTorch 2.0的编译优化
    @torch.no_grad()
    def step(self):
        """执行优化步骤"""
        rank = dist.get_rank()      # 当前进程排名
        world_size = dist.get_world_size()  # 总进程数
        
        reduce_scatter_futures: list[torch.Future] = []  # 梯度归约-分散操作的future列表
        all_reduce_futures: list[torch.Future] = []      # 参数全归约操作的future列表
        grad_slices = []  # 梯度切片列表
​
        # 第一阶段:梯度归约-分散
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            grad = torch.empty_like(params[-1]) # TODO 这是bug吗?似乎被立即覆盖了
            for base_i in range(len(params)):
                grad = params[base_i].grad  # 获取参数梯度
                rank_size = grad.shape[0] // world_size  # 每个进程处理的梯度大小
                grad_slice = torch.empty_like(grad[:rank_size])  # 为梯度切片创建空张量
                
                # 异步执行归约-分散操作:将梯度平均分散到各个进程
                reduce_scatter_futures.append(
                    dist.reduce_scatter_tensor(
                        grad_slice, grad, op=dist.ReduceOp.AVG, async_op=True
                    ).get_future()
                )
                grad_slices.append(grad_slice)
​
        # 第二阶段:参数更新
        idx = 0  # 梯度切片索引
        for group in self.param_groups:
            beta1, beta2 = group['betas']  # Adam的动量参数
            eps = group['eps']             # 数值稳定性常数
            wd = group['weight_decay']     # 权重衰减
            params = group['params']
            
            for base in range(len(params)):
                # 等待梯度归约完成
                reduce_scatter_futures[idx].wait()
                p = params[base]  # 当前参数
                rank_size = p.shape[0] // world_size
                # 获取当前进程负责的参数切片
                p_slice = p[rank * rank_size:(rank + 1) * rank_size]
                lr = group['lr'] * getattr(p, "lr_mul", 1.0)  # 学习率,支持参数特定的乘子
                
                state = self.state[p]  # 参数状态(动量、二阶矩等)
                g_slice = grad_slices[idx]  # 对应的梯度切片
​
                # 状态初始化
                if not state:
                    state['step'] = torch.tensor(0, dtype=torch.int64, device=p.device)
                    state['exp_avg'] = torch.zeros_like(p_slice)      # 一阶动量
                    state['exp_avg_sq'] = torch.zeros_like(p_slice)   # 二阶动量
​
                exp_avg = state['exp_avg']      # 一阶动量(类似动量)
                exp_avg_sq = state['exp_avg_sq'] # 二阶动量(类似RMS)
                state['step'] += 1              # 步数计数
                t = state['step']               # 当前步数
​
                # 权重衰减(在梯度更新前应用)
                if wd != 0:
                    eff_weight_decay = lr * wd * getattr(p, "wd_mul", 1.0)  # 有效权重衰减
                    p_slice.mul_(1 - eff_weight_decay)  # 原地应用权重衰减
​
                # 更新运行平均值(动量更新)
                exp_avg.mul_(beta1).add_(g_slice, alpha=1 - beta1)           # 一阶动量
                exp_avg_sq.mul_(beta2).addcmul_(g_slice, g_slice, value=1 - beta2)  # 二阶动量
​
                # 偏差校正(Adam的特性)
                bias1 = 1 - beta1 ** t  # 一阶动量偏差校正
                bias2 = 1 - beta2 ** t  # 二阶动量偏差校正
​
                # 计算更新步长
                denom = exp_avg_sq.sqrt().add_(eps)  # 分母,添加eps防止除零
                step_size = lr * (torch.sqrt(bias2) / bias1)  # 带偏差校正的步长
                update = exp_avg.div(denom).mul_(step_size)   # 最终更新量
​
                # 应用更新
                p_slice.add_(other=update, alpha=-1.0)  # 参数更新
                
                idx += 1
                # 异步执行全归约操作:同步所有进程的更新后参数
                all_reduce_futures.append(
                    dist.all_gather_into_tensor(p, p_slice, async_op=True).get_future()
                )
        
        # 等待所有参数同步完成
        torch.futures.collect_all(all_reduce_futures).wait()

2. muon.py - Muon优化器

python

"""
Muon optimizer from Keller et al.
Also a lot of borrowing of ideas from modded-nanogpt.
Muon优化器来自Keller等人,也从modded-nanogpt借鉴了很多想法。
"""
​
import torch
from torch import Tensor
import torch.distributed as dist
​
@torch.compile
def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G.
    牛顿-舒尔茨迭代计算G的零次幂/正交化。
    We opt to use a quintic iteration whose coefficients are selected to maximize the slope at zero.
    我们选择使用五次迭代,其系数选择为在零点处最大化斜率。
    """
    assert G.ndim >= 2 # 批处理的Muon实现
    
    a, b, c = (3.4445, -4.7750,  2.0315)  # 优化后的系数
    X = G.bfloat16()  # 使用bfloat16以节省内存
    
    # 确保谱范数最多为1
    if G.size(-2) > G.size(-1):
        X = X.mT  # 转置以适应形状
​
    # 确保谱范数最多为1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    
    # 执行NS迭代
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A  # 五次计算策略
        X = a * X + B @ X
​
    if G.size(-2) > G.size(-1):
        X = X.mT  # 转置回来
    return X
​
class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-schulz
    Muon - 通过牛顿-舒尔茨正交化的动量
    
    https://kellerjordan.github.io/posts/muon/
​
    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
    matrix.
    Muon内部运行标准SGD动量,然后执行正交化后处理步骤,其中每个2D参数的更新被替换为最近的正交矩阵。
    """
    def __init__(self, params, lr=0.02, momentum=0.95, nesterov=True, ns_steps=5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        params: list[Tensor] = [*params]
        
        # 按参数大小分组,便于批处理
        param_groups = []
        for size in {p.numel() for p in params}:
            group = dict(params=[p for p in params if p.numel() == size])
            param_groups.append(group)
        super().__init__(param_groups, defaults)
​
    @torch.no_grad()
    def step(self):
        """执行优化步骤"""
        for group in self.param_groups:
            params: list[Tensor] = group["params"]
            for p in params:
                g = p.grad  # 梯度
                assert g is not None
                state = self.state[p]
                
                # 初始化动量缓冲区
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(g)
                buf: Tensor = state["momentum_buffer"]
                
                # 动量更新: buf = momentum * buf + (1 - momentum) * g
                buf.lerp_(g, 1 - group["momentum"])
                
                # Nesterov动量: g = g + momentum * buf
                g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
                
                # 正交化梯度
                g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
                
                # 应用更新,考虑纵横比缩放
                scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
                p.add_(g, alpha=-group["lr"] * scale)
​
class DistMuon(torch.optim.Optimizer):
    """
    Distributed version of Muon optimizer.
    Muon优化器的分布式版本。
    """
    def __init__(self, params, lr: float = 0.02, momentum: float = 0.95,
                 nesterov: bool = True, ns_steps: int = 5):
        defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps)
        params = list(params)
        assert all(p.ndim == 2 for p in params), "Muon expects 2D parameters only"
        
        rank = dist.get_rank()
        
        # 按形状对所有参数分组
        shapes = sorted({p.shape for p in params})  # 排序以确保一致/确定性顺序
        param_groups = []
        for shape in shapes:
            group_params = [p for p in params if p.shape == shape]
            device, dtype = group_params[0].device, group_params[0].dtype
            assert all(p.device == device for p in group_params)
            assert all(p.dtype == dtype for p in group_params)
            
            if rank == 0:
                print(f"Muon: Grouping {len(group_params)} params of shape {shape}, device {device}, dtype {dtype}")
            
            param_groups.append(dict(params=group_params, zero_buffer=torch.zeros_like(group_params[0])))
        
        super().__init__(param_groups, defaults)
​
    @torch.no_grad()
    def step(self):
        """分布式优化步骤"""
        rank = dist.get_rank()
        world_size = dist.get_world_size()
​
        # 确保所有梯度都存在
        assert all(p.grad is not None for group in self.param_groups for p in group["params"]), "All params must have grads"
​
        # 启动所有归约-分散操作来平均各个rank的梯度
        all_reduce_futures = []
        for group in self.param_groups:
            params = group["params"]
            zero_buffer = group["zero_buffer"]
            
            # 按world_size分组处理参数
            for base_i in range(0, len(params), world_size):
                # 每个参数的计算所有者是 rank i % world_size
                owner_idx = base_i + rank
                
                # 每个rank将其world_size参数的块堆叠到列表中
                rs_input = [p.grad for p in params[base_i:base_i + world_size]]
                # 用零缓冲区填充rs_input以完成组
                rs_input.extend([zero_buffer] * (world_size - len(rs_input)))
                
                # 输出缓冲区基于rank在组中跨步
                rs_output = params[owner_idx].grad if owner_idx < len(params) else torch.empty_like(zero_buffer)
                
                # 在这个world_size参数组内归约-分散梯度
                work = dist.reduce_scatter(rs_output, rs_input, op=dist.ReduceOp.AVG, async_op=True).get_future()
                all_reduce_futures.append(work)
​
        # 现在每个rank计算更新并收集
        future_idx = 0
        all_gather_futures = []
        for group in self.param_groups:
            params = group["params"]
            zero_buffer = group["zero_buffer"]
            
            # 按world_size分组处理参数
            for base_i in range(0, len(params), world_size):
                # 每个参数的计算所有者是 rank i % world_size
                owner_idx = base_i + rank  # 计算此rank拥有的参数索引
                
                # 等待归约-分散完成
                all_reduce_futures[future_idx].wait()
                future_idx += 1
                
                # 所有者计算Muon更新,结果在其参数中
                if owner_idx < len(params):
                    p = params[owner_idx]
                    g = p.grad  # 现在已在各个rank间平均
                    state = self.state[p]
                    
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(g)
                    buf: Tensor = state["momentum_buffer"]
                    
                    # 动量更新
                    buf.lerp_(g, 1.0 - group["momentum"])
                    g = g.lerp_(buf, group["momentum"]) if group["nesterov"] else buf
                    
                    # 正交化
                    g = zeropower_via_newtonschulz5(g, steps=group["ns_steps"])
                    scale = (max(1.0, p.size(-2) / p.size(-1)) ** 0.5)
                    p.add_(g, alpha=-group["lr"] * scale)
                
                # 将更新后的参数复制到所有rank
                ag_input = params[owner_idx] if owner_idx < len(params) else zero_buffer
                ag_output = params[base_i:base_i + world_size]
                ag_output.extend([torch.empty_like(zero_buffer) for _ in range(world_size - len(ag_output))]) # 填充
                
                work = dist.all_gather(ag_output, ag_input, async_op=True).get_future()
                all_gather_futures.append(work)
​
        # 等待所有工作完成
        torch.futures.collect_all(all_gather_futures).wait()

3. dataloader.py - 数据加载器

python

from collections import deque
import torch
from nanochat.common import get_dist_info
from nanochat.dataset import parquets_iter_batched
from nanochat.tokenizer import get_tokenizer
​
def tokenizing_distributed_data_loader(B, T, split, tokenizer_threads=4, tokenizer_batch_size=128):
    """
    Stream pretraining text from parquet files, tokenize, yield training batches.
    从parquet文件流式传输预训练文本,进行分词,产生训练批次。
    
    Args:
        B: 批次大小 (batch size)
        T: 序列长度 (sequence length)  
        split: 数据集分割 ("train" 或 "val")
        tokenizer_threads: 分词器线程数
        tokenizer_batch_size: 分词器批处理大小
    """
    assert split in ["train", "val"], "split must be 'train' or 'val'"
    
    ddp, ddp_rank, ddp_local_rank, ddp_world_size = get_dist_info()
    needed_tokens = B * T + 1  # +1 是因为我们还需要最后一个token的目标
    
    # 获取分词器和bos token
    tokenizer = get_tokenizer()
    bos_token = tokenizer.get_bos_token_id()
    
    # 暂存缓冲区保存一次迭代的tokens
    token_buffer = deque()  # 我们在右侧流式传输tokens并从左侧弹出
    scratch = torch.empty(needed_tokens, dtype=torch.int64, pin_memory=True)  # 固定内存以加速传输
​
    # 文档批次的无限迭代器
    def document_batches():
        while True:
            # batch将按parquet文件的组大小迭代,通常例如1024行
            for batch in parquets_iter_batched(split=split, start=ddp_rank, step=ddp_world_size):
                # 对于分词器,我们可能希望使用通常更小的批次,例如128行
                for i in range(0, len(batch), tokenizer_batch_size):
                    yield batch[i:i+tokenizer_batch_size]
    
    batches = document_batches()
    batch_index = 0
​
    while True:
        # 在产生之前积累足够的tokens用于一次迭代
        while len(token_buffer) < needed_tokens:
            doc_batch = next(batches)  # 获取下一批文档
            # 使用多线程分词
            token_lists = tokenizer.encode(doc_batch, prepend=bos_token, num_threads=tokenizer_threads)
            
            for tokens in token_lists:
                token_buffer.extend(tokens)  # 将分词的tokens添加到缓冲区
                
            batch_index += 1
​
        # 将tokens从deque移动到暂存缓冲区
        for i in range(needed_tokens):
            scratch[i] = token_buffer.popleft()  # 从左侧弹出token
            
        # 创建输入/目标作为1D张量
        inputs_cpu = scratch[:-1].to(dtype=torch.int32)  # 输入:除最后一个token外的所有token
        targets_cpu = scratch[1:]                        # 目标:除第一个token外的所有token(自回归)
        
        # 重塑为2D并异步移动到GPU
        inputs = inputs_cpu.view(B, T).to(device="cuda", dtype=torch.int32, non_blocking=True)
        targets = targets_cpu.view(B, T).to(device="cuda", dtype=torch.int64, non_blocking=True)
        
        yield inputs, targets  # 产生训练批次

关键概念解释:

  1. 分布式优化器:

    • ZeRO-2: 将优化器状态分片到不同GPU,减少内存使用

    • reduce-scatter: 梯度归约和分散操作

    • all-gather: 参数同步操作

  2. Muon优化器特点:

    • 结合了SGD动量和矩阵正交化

    • 使用牛顿-舒尔茨迭代进行高效正交化

    • 特别适合Transformer的线性层

  3. 数据加载器设计:

    • 流式处理: 不从磁盘一次性加载所有数据

    • 异步操作: 数据预处理和GPU传输重叠

    • 分布式支持: 每个GPU处理数据的不同部分

  4. 内存优化技术:

    • 固定内存 (pin_memory): 加速CPU到GPU的数据传输

    • 非阻塞传输 (non_blocking): 允许计算和数据传输重叠

    • 缓冲区管理: 使用deque高效管理token流

这些组件共同构成了高效训练大型语言模型的基础设施。优化器负责参数更新策略,数据加载器确保高效的数据供给。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值