Sequence Parallelism Explained (xtuner / DeepSpeed-Ulysses)

Sequence parallelism is a technique for training on very long sequences (64k, 128k, etc.). To some extent it can also extend a large model's long-context capability.

Memory usage of Transformer activations

Recommended reading: the Zhihu article 分析transformer模型的参数量、计算量、中间激活、KV cache (an analysis of a Transformer model's parameter count, compute, intermediate activations, and KV cache).

From that analysis, the memory occupied by the intermediate activations of $l$ Transformer layers is $(34bsh + 5bs^{2}a) \cdot l$ bytes (assuming fp16), where:

  • $b$: batch size
  • $s$: sequence length
  • $h$: hidden dimension
  • $a$: number of attention heads

For the qwen1.5-32b-chat model, $h=4096$, $a=8$ (GQA), $l=64$. Suppose we train with batch size 1 on a 32k-token sequence. The intermediate activations of the $l$ layers then come to:
$(34 \times 1 \times 32000 \times 4096 + 5 \times 1 \times 32000 \times 32000 \times 8) \times 64 / 1000 / 1000 / 1000 \approx 2906\,\mathrm{GB}$
Even the activations of a single layer come to roughly 45 GB. These activations have to be kept in GPU memory for the backward pass. With gradient_checkpointing enabled, only some of the activations are stored, which significantly reduces activation memory.
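
A quick sanity check of the arithmetic above, as a minimal Python sketch (the formula and the decimal 1000³-per-"GB" convention follow the text; the helper name is made up):

def activation_gb(b: int, s: int, h: int, a: int, l: int) -> float:
    # (34*b*s*h + 5*b*s^2*a) bytes per layer (fp16), times l layers
    per_layer = 34 * b * s * h + 5 * b * s * s * a
    return per_layer * l / 1000 / 1000 / 1000

total = activation_gb(b=1, s=32000, h=4096, a=8, l=64)
print(total, total / 64)   # ~2906.7 GB in total, ~45.4 GB per layer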

However, when we tried this on qwen1.5-32b, even with gradient_checkpointing plus DeepSpeed ZeRO-3, training on 32k-token text across 24 A100s still ran out of memory while some layer was computing its attention scores. The problem in this case is not too few GPUs; it is that the peak memory usage is too high.

Recap: how Transformer attention is computed

[Figure: the attention computation graph; the intermediate tensors are marked with red labels 1–5 and are referenced below]

The figure above shows the attention computation. Consider how much memory $attn\_output$ takes. For qwen1.5-32b, with $bs=1$, $nh=8$ (GQA), $S=32000$, $attn\_output$ occupies $8 \times 32000 \times 32000 \times 2 / 1000 / 1000 / 1000 \approx 16\,\mathrm{GB}$. If every layer had to store such a matrix, a single A100 could never hold them. $attn\_output$ must therefore be split and placed on different GPUs.

Looking closely, the only dimension of $attn\_output$ we can split is $nh$: we cut it into blocks along the head dimension and compute each block on a different GPU.

Having established that $attn\_output$ can only be split along $nh$, we can work through the graph bottom-up (following the red labels) and see that there are actually several ways to achieve this:

  1. Split the second dimension of the tensor labeled 4.
  2. Split the first dimension of the tensor labeled 2.
  3. Split the second dimension of the tensor labeled 3.

However, none of these three approaches is what xtuner / DeepSpeed-Ulysses actually uses. Both libraries instead split the tensor labeled 5, i.e. $hidden\_states$, along the $S$ dimension, and only later convert that split to the $nh$ dimension via a transformation. This reduces memory usage even further.

Sequence parallel

[Figure: sequence-parallel sharding of the input sequence across the GPUs of one sp_group]

Sequence parallelism works as shown above. The sequence parallel size is the number of GPUs that $attn\_output$ is spread over; call it $sp$, and call these $sp$ GPUs an $sp\_group$.
When computing $q\_states$, $k\_states$, $v\_states$, the sequence of length $S$ is split into $sp$ chunks of length $S/sp$ each. In terms of position_ids (see the sketch after this list):

  • $sp\_rank=0$ handles position_ids in $[0, S/sp-1]$
  • $sp\_rank=1$ handles position_ids in $[S/sp, 2S/sp-1]$
  • $sp\_rank=sp-1$ handles position_ids in $[S(sp-1)/sp, S-1]$
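
A minimal sketch of this chunking (the helper name is hypothetical; it just enumerates which position_ids a given sp_rank owns, assuming $S$ is divisible by $sp$):

import torch

def local_position_ids(seq_len: int, sp_size: int, sp_rank: int) -> torch.Tensor:
    chunk = seq_len // sp_size                    # S / sp tokens per rank
    start = sp_rank * chunk                       # sp_rank * S / sp
    return torch.arange(start, start + chunk)     # [start, start + S/sp - 1]

# e.g. S = 32000, sp = 4: sp_rank 0 -> [0, 7999], sp_rank 3 -> [24000, 31999]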

Before entering $attention\_forward$ to compute $attn\_output$, the split dimension is switched from $S$ to $nh$ (an all-to-all within the sp_group), so each GPU only holds the attention scores for $nh/sp$ heads, i.e. a $[bs, nh/sp, S, S]$ matrix that takes $1/sp$ of the original memory, and each GPU computes $nh/sp$ of the attention heads. Before the subsequent computation, the per-head attention output is transformed back the other way: from $[bs, nh/sp, S, head\_dim]$ (the full sequence, a fraction of the heads) to $[bs, nh, S/sp, head\_dim]$ (all heads, a fraction of the sequence).
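
A minimal sketch of the sequence-to-head re-sharding (this is not xtuner's or DeepSpeed-Ulysses' actual code; the function name and the $[bs, S/sp, nh, hd]$ layout are assumptions for illustration):

import torch
import torch.distributed as dist

def seq_to_head_shard(x: torch.Tensor, sp_group) -> torch.Tensor:
    """Turn a sequence-sharded tensor [bs, S/sp, nh, hd] into a head-sharded
    tensor [bs, S, nh/sp, hd] with one all-to-all inside the sp_group."""
    sp = dist.get_world_size(group=sp_group)
    bs, s_local, nh, hd = x.shape
    assert nh % sp == 0
    # Send head-chunk j of our local sequence chunk to rank j of the sp_group.
    send = [c.contiguous() for c in x.chunk(sp, dim=2)]     # sp x [bs, S/sp, nh/sp, hd]
    recv = [torch.empty_like(send[0]) for _ in range(sp)]
    dist.all_to_all(recv, send, group=sp_group)
    # recv[i] holds sequence chunk i for our nh/sp heads; stitch the full sequence back.
    return torch.cat(recv, dim=1)                           # [bs, S, nh/sp, hd]

The reverse re-sharding (head back to sequence) is the same pattern with the chunk and concat dimensions swapped.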

DistributedSampler

In distributed training we need a DistributedSampler, which decides which samples each GPU gets.

import math
from typing import Iterator, Optional, TypeVar

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, Sampler

T_co = TypeVar("T_co", covariant=True)


class DistributedSampler(Sampler[T_co]):
    r"""Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
    :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
    original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size and that any instance of it always
        returns the same elements in the same order.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, :attr:`world_size` is retrieved from the
            current distributed group.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
            indices.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: ``False``.

    .. warning::
        In distributed mode, calling the :meth:`set_epoch` method at
        the beginning of each epoch **before** creating the :class:`DataLoader` iterator
        is necessary to make shuffling work properly across multiple epochs. Otherwise,
        the same ordering will be always used.

    Example::

        >>> # xdoctest: +SKIP
        >>> sampler = DistributedSampler(dataset) if is_distributed else None
        >>> loader = DataLoader(dataset, shuffle=(sampler is None),
        ...                     sampler=sampler)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     if is_distributed:
        ...         sampler.set_epoch(epoch)
        ...     train(loader)
    """

    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch

In __init__ we can see num_replicas and rank:

  • num_replicas: the total number of GPUs (processes).
  • rank: which GPU this is (the global rank).
    For example, with 3 nodes of 8 A100s each, the third GPU of the second node has num_replicas=24, rank=10.

As the code shows, PyTorch's official DistributedSampler assigns different samples to each GPU. As mentioned above, sequence parallelism instead needs the GPUs within one sp_group to receive the same sample, which is then split into sp chunks, one per GPU. So DistributedSampler has to be re-implemented.
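
One minimal way to do this (a sketch, not xtuner's actual sampler; the class name is made up, and it assumes each sp_group is made of sp consecutive global ranks) is to reuse DistributedSampler but treat each sp_group as a single data-parallel replica, so that all ranks in a group draw identical indices:

import torch.distributed as dist
from torch.utils.data import DistributedSampler

class SequenceParallelSampler(DistributedSampler):
    def __init__(self, dataset, sp_size: int, shuffle: bool = True, seed: int = 0):
        world_size = dist.get_world_size()
        rank = dist.get_rank()
        assert world_size % sp_size == 0
        # Ranks in the same sp_group share rank // sp_size, so they sample the
        # same indices; different sp_groups still see disjoint subsets of data.
        super().__init__(dataset,
                         num_replicas=world_size // sp_size,
                         rank=rank // sp_size,
                         shuffle=shuffle,
                         seed=seed)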

An example

Suppose there are 3 nodes with 8 A100s each and sequence parallel size = 4. That gives 6 sp_groups, and each sp_group processes one sample. In other words, a single forward pass can only handle 6 samples (for very long text, batch_size can only be 1).
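
With that layout, the process groups could be built as follows (a sketch under the assumption that each sp_group consists of sp consecutive global ranks; this is not necessarily how xtuner constructs its groups):

import torch.distributed as dist

def build_sp_group(sp_size: int):
    world_size = dist.get_world_size()   # e.g. 24
    rank = dist.get_rank()
    assert world_size % sp_size == 0
    my_group = None
    # Every rank must call new_group() for every group, in the same order.
    for start in range(0, world_size, sp_size):
        ranks = list(range(start, start + sp_size))   # [0,1,2,3], [4,5,6,7], ...
        group = dist.new_group(ranks)
        if rank in ranks:
            my_group = group
    return my_group   # the sp_group this rank belongs to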

When gradients are computed, however, DDP averages over the total number of GPUs, so the averaged gradient of these 6 samples is necessarily smaller than the averaged gradient of 24 samples. Therefore gradient_accumulation_steps needs to be set to the sequence parallel size.
