详解PyTorch FSDP数据并行(Fully Sharded Data Parallel)

1. 背景介绍

全切片数据并行(Fully Sharded Data Parallel,简称为FSDP)是数据并行的一种新的方式,FSDP最早是在2021年在FairScale-FSDP中提出的,后来合入了PyTorch 1.11版本中。微软之前Deepspeed框架中提出过三种级别的ZERO算法,FSDP可以看成是ZERO-3的实现。

2. 详细介绍

传统的数据并行(DDP)是在每一个GPU卡上保存整个model的参数/梯度/优化器状态, 然后对数据集切分为 N N N 个shard分片给不同的GPU进行训练,计算完梯度后通过all-reduce通信来做梯度的融合。如下图:
在这里插入图片描述

在FSDP中的主要思路是想办法把model的梯度/优化器状态/参数都进行切分操作,每个GPU只存部分的参数信息,也就是在ZERO-3的思路。为了能把所有的参数进行分片处理,核心在于要把DDP中的all-reduce操作拆解为reduce-scatterall-gather 操作。

在这里插入图片描述

如下图,在进行FSDP前向计算其中的一层Layer时,由于每个GPU都只保存了部分参数,所以需要先通过all-gather操作获得全部的参数;同理,在反向计算过程中,也需要通过all-gather操作,获得全部的参数;最后计算出来的梯度只是部分的结果,需要通过reduce-scatter通信进行累加操作,最终每个GPU卡分别只更新自己那部分参数(也就是local本地weight更新)。

在这里插入图片描述

FSDP的应用是对原有model layers加上了一层wrapper封装,只有在FSDP实例中的layer才会在前向和后向过程中执行gather相关操作,通过切分可以利用相同的显存大小训练更大的模型。为了进一步提升显存利用率,FSDP也支持把不活跃的实例全部offload调出到CPU上去。

FSDP计算过程的伪码如下:

FSDP forward pass:
    for layer_i in layers:
        all-gather full weights for layer_i
        forward pass for layer_i
        discard full weights for layer_i

FSDP backward pass:
    for layer_i in layers:
        all-gather full weights for layer_i
        backward pass for layer_i
        discard full weights for layer_i
        reduce-scatter gradients for layer_i

在PyTorch中的示例如下, 通过FullyShardedDataParallel实现对model的封装,通过CPUOffload来决定采用哪种策略把参数调到CPU上。

from torch.distributed.fsdp import (
   FullyShardedDataParallel,
   CPUOffload,
)
from torch.distributed.fsdp.wrap import (
   default_auto_wrap_policy,
)
import torch.nn as nn
 
class model(nn.Module):
   def __init__(self):
       super().__init__()
       self.layer1 = nn.Linear(8, 4)
       self.layer2 = nn.Linear(4, 16)
       self.layer3 = nn.Linear(16, 4)
 
model = DistributedDataParallel(model())
fsdp_model = FullyShardedDataParallel(
   model(),
   fsdp_auto_wrap_policy=default_auto_wrap_policy,
   cpu_offload=CPUOffload(offload_params=True),
)

使用FSDP训练GPT-175B和GPT-1T参数量大小的模型,词表大小50K,fp16的精度和使用SGD的优化器。

在这里插入图片描述

结果如下,使用FSDP时在GPU卡数增大的情况下,对GPU单卡的吞叶没有影响;在A100-40G机器下增大batch_size 但吞吐没有增加, 瓶颈不在于通信而是CUDA cache的分配到了瓶颈;当换为A100-80G机器时,CUDA cache的分配问题得到解决后,增大batch_size后吞吐进一步增加。

在这里插入图片描述

3. 参考

  • 5
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

qinduohao333

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值