基于 DeepSpeed-Ulysses 的 SequenceParallel 介绍及软件实现

基于 DeepSpeed-Ulysses 的 SequenceParallel 介绍及软件实现

欢迎到访:ckblogs.cn

由于官方示例中使用的 deepspeed-ulysses 是 基于 megatron-deepspeed 的,而对于普通的 deepspeed 并没有例程,这里的实现仅按照个人想法实现,如果有问题,欢迎与我交流讨论。

方法介绍

从生成性AI到科研模型,长序列训练正在变得非常重要。 在生成性AI领域,会话式AI、长文档摘要和视频生成等任务都需要在空间和时间层面对长上下文进行推理。长序列长度的重要性逐渐增长,但现有的大型模型训练系统和底层的并行技术(数据、张量、流水线、序列并行)并不能支持高效的长序列训练。为了解决这些问题,微软DeepSpeed宣布推出 DeepSpeed-Ulysses,这是一种简单、易用且高效的方法,用于支持具有极长序列长度的高效可扩展LLM训练。

DeepSpeed-Ulysses将各个样本在序列维度上分割给参与的GPU。然后,在 attention 计算之前,它对已分割的 Q、K、V执行 all-to-all通信 操作,以使每个 GPU 接收完整的序列,但仅用于注意力头的非重叠子集。这使得参与的GPU可以并行计算不同的注意力头。最后,DeepSpeed-Ulysses 使用另一个 all-to-all 来在注意力头上收集结果,同时重新在序列维度上进行分区。

传统的 Attention 计算过程:

image-20240412103013786

DeepSpeed- Ulysses 计算过程:

image-20240412103042454

DeepSpeed- Ulysses 的核心在于切分 Q、K、V 后进行的 All-to-All 通信方式,这个通信方式同 Allreduce 一样,是分布式训练中的 Collective functions,All-to-All 在每个进程向每个其他进程发消息的一部分,最后处理器拥有各个进程消息的一部分。他的作用相当于分布式转置Transpose操作。

image-20240412120330612

用这个 pytorch 的示例更好理解:

>>> input = torch.arange(4) + rank * 4
>>> input
tensor([0, 1, 2, 3])     # Rank 0
tensor([4, 5, 6, 7])     # Rank 1
tensor([8, 9, 10, 11])   # Rank 2
tensor([12, 13, 14, 15]) # Rank 3
>>> output = torch.empty([4], dtype=torch.int64)
>>> dist.all_to_all_single(output, input)
>>> output
tensor([0, 4, 8, 12])    # Rank 0
tensor([1, 5, 9, 13])    # Rank 1
tensor([2, 6, 10, 14])   # Rank 2
tensor([3, 7, 11, 15])   # Rank 3

我对整个过程做了一个更直白的方式进行描述,我将这个过程分成两步:

  1. 第一步,对 K、Q、V,分别按照序列长度进行分割,得到 Local-Q、Local-K、Local-V,这里的 P 表示分割的数量,也就是对每一个 GPU 放入的序列长度是 N/P,通过 All-to-All 通信,获取序列长度为 N 但进行注意力头切割的 K h K_h Kh Q h Q_h Qh V h V_h Vh :

image-20240412105815699

  1. 第二步,计算 O h O_h Oh ,再通过 All-to-All 通信,获取 Local-O,在每个进程内,使用 Local O,进行 后续操作。
    O h = S o f t m a x ( Q h K h T ) V h O_h=Softmax(Q_hK_h^T)V_h Oh=Softmax(QhKhT)Vh
    image-20240412110745541

    这样去做的意义就在于对于每一个设备送入处理的序列长度只是 N/P,这样大大减少了设备的运算量,但同时,因为 All-To-All 通信方式的存在,使得即使每个设备被送入的序列长度只有 N/P,但他们做 Attention 时还是考虑了整体的长度,可谓巧妙。

软件实现

对于将已有的训练代码转为基于 DeepSpeed-Ulysses 的序列并行训练,主要修改的代码包括三个部分:

定义序列并行 group

在定义 data engine 之前,需要先定义 sequence_parallel_group ,这里的 sequence_parallel_size 指每一个序列并行内进程的数量,sequence_parallel_size 数值越大,一个 sequence_parallel_group 中的进程数量越多,每张卡里送入的序列长度就越小,假设 world_size 为 8,sequence_parallel_size 为 4,那么 world 将会被分成这样:

image-20240412112819123

_SEQUENCE_PARALLEL_GROUP = None

def initialize_model_parallel(
    sequence_parallel_size,

):
    world_size = dist.get_world_size()
    num_sequence_parallel_groups: int = world_size // sequence_parallel_size
    global _SEQUENCE_PARALLEL_GROUP
    for i in range(num_sequence_parallel_groups):
        ranks = range(i * sequence_parallel_size, (i + 1) * sequence_parallel_size)
        group = torch.distributed.new_group(ranks)
        rank = dist.get_rank()
        if rank in ranks:
            _SEQUENCE_PARALLEL_GROUP = group

def get_sequence_parallel_group():
    """Get the sequence parallel group the caller rank belongs to."""
    return _SEQUENCE_PARALLEL_GROUP

def get_sequence_parallel_world_size():
    """Get the sequence parallel world size."""
    return dist.get_world_size(group=get_sequence_parallel_group())

def get_sequence_parallel_rank():
    """Get the sequence parallel rank."""
    return dist.get_rank(group=get_sequence_parallel_group())

修改 Attention 实现

这里我们使用 FlashAttention2 的 Attention ,将原本的

def forward(self, ...):
	...
    attn_output = flash_attn_func(
        query_states,
        key_states,
        value_states,
        dropout,
        softmax_scale=softmax_scale,
        causal=causal,
    )
    ...

替换为:

def __init__(self, ...):
    ...
    self.dist_ulysser_attn = UlyssesAttention()
    ...
    
    
def forward(self, ...):
	...
    attn_output = self.dist_ulysser_attn(
        query_states,
        key_states,
        value_states,
        dropout,
        softmax_scale=softmax_scale,
        causal=causal,
    )
    ...

这里的 UlyssesAttention() 通过上述的 DeepSpeed-Ulysses 技巧实现,具体的 Attention 如下:

这里参考了 feifeibear/long-context-attention 的实现方式。

class UlyssesAttention(torch.nn.Module):
    """Initialization.

        scatter_idx (int): scatter_idx for all2all comm
        gather_idx (int): gather_idx for all2all comm
    """

    def __init__(
        self,
        scatter_idx: int = 2,
        gather_idx: int = 1,
    ) -> None:

        super(UlyssesAttention, self).__init__()
        self.spg = get_sequence_parallel_group()
        self.scatter_idx = scatter_idx
        self.gather_idx = gather_idx

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
        *args: Any
    ) -> Tensor:

        q = SeqAllToAll4D.apply(self.spg, query, self.scatter_idx, self.gather_idx)
        k = SeqAllToAll4D.apply(self.spg, key, self.scatter_idx, self.gather_idx)
        v = SeqAllToAll4D.apply(self.spg, value, self.scatter_idx, self.gather_idx)

        context_layer = flash_attn_func(
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            dropout_p=dropout_p,
            causal=causal,
        )

        if  isinstance(context_layer, tuple):
            context_layer = context_layer[0]

        output = SeqAllToAll4D.apply(
            self.spg, context_layer, self.gather_idx, self.scatter_idx
        )

        return output

修改训练传入数据——做序列长度分割

这里也是能够有效降低显存的主要原因,也就是送入每个 GPU 的序列长度被分割为了多个部分,每个GPU只处理其中一个部分:

...
# get sequence parallel sub_seq_start and sub_seq_end
seq_parallel_rank = get_sequence_parallel_rank()
seq_parallel_world_size = get_sequence_parallel_world_size()
seq_length = self.max_length // seq_parallel_world_size
self.sub_seq_start = seq_parallel_rank * seq_length
self.sub_seq_end = (seq_parallel_rank + 1) * seq_length

...
while True:
    model_engine.train()
    step = 0
    while step < args.steps_per_epoch:
        # to split data
        data = next(data_iter)[:, data_engine.sub_seq_start:data_engine.sub_seq_end]
        loss = model_engine(data, labels=data).loss
        model_engine.backward(loss)
        model_engine.step()
        step += 1

    epoch += 1
    new_steps = args.laststep + epoch*args.steps_per_epoch
    model_engine.save_checkpoint(f"{args.checkpoint_saving_path}",
                                    tag=f"checkpoint-{new_steps}")

其他

方法评价

引用下原文的一些其他优点:

  • 与现有系统相比,序列长度增加了 4 倍,支持训练超过百万个token的序列。
  • 与现有系统相比,通信减少了超过 10 倍,导致吞吐量提高了高达2.5倍,并且每个 GPU 的持续吞吐量超过 175 TFlops(超过硬件峰值的54%)。
  • 完全通用的 attention:DeepSpeed 序列并行支持密集和稀疏的注意力,并可与高效的注意力实现(如FlashAttention v2)一起工作。
  • 支持大规模模型训练:DeepSpeed 序列并行不仅支持大序列长度,还可以与 ZeRO-3 并用支持大模型尺寸。
  • 易于使用和迁移,最小化对现有训练框架的代码更改要求。

缺点:Ulysses也有明显缺点,就是转置后切分维度d/P,我们希望d/P=hc/P * head_size,即对head_cnt所在维度切分,这样Attention的计算都在一张卡上完成,从而可以使用FlashAttention等单卡优化。但是如果遇到GQA或者MQA情况,K、V的head_cnt很小,导致GPU数目P也不能变得很大。。

Ulysses和Ring-Attention 混合并行

Ulysses和Ring-Attention 可以组成一个混合序列并行方案。同时克服并行度<=num_head的限制,和避免P2P低效带宽利用。比如,在下面8xA100 NVLink, num_head=8上可以相比ring-flash-attn有18%和31%的训练和推理性能提升。

实现可以见:Long-Context-Attention (YunChang-云长): Distributed Attention Implementations for Long Context LLM Model Training and Inference

参考博文及介绍

[1] DeepSpeed Ulysses: System Optimizations for Enabling Training of Extreme Long Sequence Transformer Models

[2] DeepSpeed Ulysses: 训练极长序列Transformer模型的系统优化

[3] Getting Started with DeepSpeed-Ulysses for Training Transformer Models with Extreme Long Sequences

[4] 大模型训练之序列并行双雄:DeepSpeed Ulysses & Ring-Attention

[5] feifeibear/long-context-attention

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值