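Sequence parallelism is a technique for training on ultra-long texts (64k, 128k tokens, etc.). To some extent, it can also be used to extend a large model's long-context capability.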
GPU memory usage of Transformer activations
From the link above, the intermediate activations of an $l$-layer Transformer occupy $(34bsh + 5bs^2a) \cdot l$ bytes of GPU memory (assuming 16-bit activations), where:
- $b$: batch size
- $s$: sequence length
- $h$: hidden dimension
- $a$: number of attention heads
For the qwen1.5-32b-chat model, $h=4096$, $a=8$ (GQA) and $l=64$. Assuming a batch size of 1 and training sequences of length 32k ($s=32000$), the activations of all $l$ layers take:

$$(34 \cdot 1 \cdot 32000 \cdot 4096 + 5 \cdot 1 \cdot 32000 \cdot 32000 \cdot 8) \cdot 64 / 1000^3 \approx 2906\ \text{GB}$$
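As a quick check, the formula can be evaluated directly. The sketch below only plugs in the values assumed above. `activation_bytes` is an illustrative name and is not part of any library.

```python
def activation_bytes(b: int, s: int, h: int, a: int, l: int) -> int:
    """Activation memory in bytes for l Transformer layers: (34bsh + 5bs^2a) * l."""
    per_layer = 34 * b * s * h + 5 * b * s * s * a
    return per_layer * l


# Values used in the text: batch size 1, 32k tokens, h=4096, a=8, 64 layers.
total = activation_bytes(b=1, s=32000, h=4096, a=8, l=64)
per_layer = activation_bytes(b=1, s=32000, h=4096, a=8, l=1)

print(f"total:     {total / 1e9:.1f} GB")      # total:     2906.7 GB
print(f"per layer: {per_layer / 1e9:.1f} GB")  # per layer: 45.4 GB
```

Setting `l=1` gives the single-layer figure, which is about 45 GB.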
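Even the activations of a single layer take about 45 GB. These activations must stay in GPU memory so that gradients can be computed during backpropagation. Enabling gradient_checkpointing keeps only some of them and recomputes the rest during the backward pass, which substantially reduces activation memory.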
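However, we tried this on qwen1.5-32b with gradient_checkpointing + deepspeed_zero3 enabled. Training on 32k-token texts with 24 A100 GPUs still ran out of memory while computing the attention scores of one layer. The cause is not too few GPUs but an excessively high peak memory footprint.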
Revisiting the Transformer attention computation
The figure above shows the attention computation. Consider the memory taken by $attn_{output}$, the $[bs, nh, S, S]$ attention matrix. For qwen1.5-32b, take $bs=1$, $nh=8$ (GQA) and $S=32000$. With 2 bytes per element in fp16/bf16, $attn_{output}$ then occupies $8 \cdot 32000 \cdot 32000 \cdot 2 / 1000^3 \approx 16\ \text{GB}$. If every layer stored such a matrix, a single A100 certainly could not hold them all, so $attn_{output}$ must be split and placed on different GPUs.
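The 16 GB figure can be reproduced the same way. This sketch assumes 2 bytes per element (fp16/bf16) and the values used above, and the helper name is illustrative. Multiplying by the 64 layers shows why keeping these matrices is hopeless on a single A100, which has 40 GB or 80 GB of memory.

```python
def attn_matrix_bytes(bs: int, nh: int, S: int, bytes_per_elem: int = 2) -> int:
    """Memory of one [bs, nh, S, S] attention matrix (2 bytes/element for fp16/bf16)."""
    return bs * nh * S * S * bytes_per_elem


one_layer = attn_matrix_bytes(bs=1, nh=8, S=32000)
all_layers = one_layer * 64  # if each of the 64 layers kept its own matrix

print(f"one layer: {one_layer / 1e9:.3f} GB")  # one layer: 16.384 GB
print(f"64 layers: {all_layers / 1e9:.1f} GB")  # 64 layers: 1048.6 GB
```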
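A closer look shows that $attn_{output}$ can only be split along its $nh$ dimension. Because the attention heads are computed independently of one another, they can be divided into chunks and computed on different GPUs.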
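Splitting along $nh$ is exact because each head's attention is computed independently. The toy NumPy check below illustrates this with small, arbitrary sizes. For brevity it omits the causal mask and GQA, so queries, keys and values all have $nh$ heads. `attention` and `softmax` are illustrative helpers.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Scaled dot-product attention on [nh, S, d] arrays (no mask)."""
    d = q.shape[-1]
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)  # [nh, S, S] attention matrix
    return softmax(scores) @ v                      # [nh, S, d]


rng = np.random.default_rng(0)
nh, S, d, sp = 8, 16, 4, 4  # toy sizes; sp = number of simulated GPUs
q, k, v = (rng.standard_normal((nh, S, d)) for _ in range(3))

full = attention(q, k, v)

# Each simulated GPU gets nh/sp consecutive heads and never needs the other heads.
parts = [attention(qc, kc, vc)
         for qc, kc, vc in zip(np.split(q, sp), np.split(k, sp), np.split(v, sp))]
split = np.concatenate(parts, axis=0)

print(split.shape, np.allclose(full, split))  # (8, 16, 4) True
```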
Having established that $attn_{output}$ can only be split along $nh$, we analyze the computation bottom-up, following the red labels in the figure. There are in fact several ways to achieve this:
- Split the second dimension of the tensor labeled 4.
- Split the first dimension of the tensor labeled 2.
- Split the second dimension of the tensor labeled 3.
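However, none of these three methods is what xtuner and DeepSpeed-Ulysses (deepspeed-uly) use. These two packages split the tensor labeled 5, i.e. $hidden\_states$, along the $S$ dimension. They then transform it so that it is split along $nh$, which reduces memory usage further.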
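A rough per-layer estimate shows why splitting $hidden\_states$ helps further. This rests on a simplifying assumption. Splitting inside attention (tensors 2, 3 or 4) divides only the $5bs^2a$ term by $sp$, while the $34bsh$ term stays replicated on every GPU. Splitting $hidden\_states$ along $S$ divides both terms by $sp$. Real frameworks deviate from this first-order model, so the numbers are illustrative only.

```python
def per_gpu_bytes_split_heads_only(b: int, s: int, h: int, a: int, sp: int) -> int:
    # Assumption: hidden_states, QKV and MLP activations keep the full sequence
    # on every GPU; only the attention-matrix term is divided across sp GPUs.
    return 34 * b * s * h + 5 * b * s * s * a // sp


def per_gpu_bytes_split_sequence(b: int, s: int, h: int, a: int, sp: int) -> int:
    # Assumption: hidden_states are split along S and attention along nh after
    # the S -> nh transform, so to first order both terms shrink by sp.
    return (34 * b * s * h + 5 * b * s * s * a) // sp


args = dict(b=1, s=32000, h=4096, a=8, sp=4)
heads_only = per_gpu_bytes_split_heads_only(**args)
sequence = per_gpu_bytes_split_sequence(**args)
print(f"split inside attention: {heads_only / 1e9:.1f} GB per layer per GPU")  # 14.7
print(f"split hidden_states:    {sequence / 1e9:.1f} GB per layer per GPU")    # 11.4
```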
Sequence parallelism
Sequence parallelism is illustrated in the figure above. The sequence parallel size is the number of GPUs across which $attn_{output}$ is distributed; denote it by $sp$. These $sp$ GPUs form one $sp\_group$.
When computing $q_{states}$, $k_{states}$ and $v_{states}$, the sequence of length $S$ is split into $sp$ chunks of length $S/sp$ each. In terms of position_ids:
- $sp\_rank=0$ computes the position_ids in the range $[0, S/sp-1]$;
- $sp\_rank=1$ computes the position_ids in the range $[S/sp, 2S/sp-1]$;
- ...
- $sp\_rank=sp-1$ computes the position_ids in the range $[S(sp-1)/sp, S-1]$.
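Here is a minimal sketch of how one rank might select its chunk of the sequence and the matching global position_ids. It assumes $S$ is divisible by $sp$. `shard_for_rank` is a hypothetical helper, not an xtuner or DeepSpeed API.

```python
from typing import List, Tuple


def shard_for_rank(input_ids: List[int], sp: int, sp_rank: int) -> Tuple[List[int], List[int]]:
    """Return (token chunk, position_ids) for sp_rank; requires len(input_ids) % sp == 0."""
    S = len(input_ids)
    if S % sp != 0:
        raise ValueError(f"sequence length {S} is not divisible by sp={sp}")
    chunk = S // sp
    start, end = sp_rank * chunk, (sp_rank + 1) * chunk
    # position_ids keep their global values: [sp_rank*S/sp, (sp_rank+1)*S/sp - 1]
    return input_ids[start:end], list(range(start, end))


tokens = list(range(100, 108))  # a toy sequence of S=8 token ids
for r in range(4):
    print(r, shard_for_rank(tokens, sp=4, sp_rank=r))
```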
Before entering $attention\_forward$ to compute $attn_{output}$, the split dimension is changed from $S$ to $nh$. In DeepSpeed-Ulysses this is an all-to-all communication within the $sp\_group$. Each GPU then handles the computation of $nh/sp$ attention heads and stores only $1/sp$ of the original $attn_{output}$, with shape $[bs, nh/sp, S, S]$. Before the subsequent computation, the attention result must be converted back from a head split to a sequence split. After multiplying by $V$, it goes from $[bs, nh/sp, S, d]$ to $[bs, nh, S/sp, d]$, where $d$ is the per-head dimension.
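The change of split dimension can be simulated in a single process with NumPy. In DeepSpeed-Ulysses this step is an all-to-all collective within the $sp\_group$. Here the simulated GPUs are just list entries. The per-rank layout ($[S/sp, nh, d]$ for one sample) and the function names are assumptions made for illustration. The inverse function turns a head-split tensor back into a sequence-split one.

```python
import numpy as np


def seq_to_heads(shards, sp):
    """Simulated all-to-all: rank r holds [S/sp, nh, d]; afterwards [S, nh/sp, d]
    containing heads r*nh/sp .. (r+1)*nh/sp - 1 for the full sequence."""
    # Rank src cuts its shard into sp head groups and "sends" group dst to rank dst.
    pieces = [np.split(shard, sp, axis=1) for shard in shards]
    return [np.concatenate([pieces[src][dst] for src in range(sp)], axis=0)
            for dst in range(sp)]


def heads_to_seq(shards, sp):
    """Inverse all-to-all: rank r holds [S, nh/sp, d]; afterwards [S/sp, nh, d]."""
    pieces = [np.split(shard, sp, axis=0) for shard in shards]
    return [np.concatenate([pieces[src][dst] for src in range(sp)], axis=1)
            for dst in range(sp)]


rng = np.random.default_rng(0)
S, nh, d, sp = 8, 4, 2, 2
x = rng.standard_normal((S, nh, d))   # full tensor for one sample, e.g. q_states
seq_shards = np.split(x, sp, axis=0)  # what each rank holds after the S split

head_shards = seq_to_heads(seq_shards, sp)
print(head_shards[0].shape)                        # (8, 2, 2): full S, nh/sp heads
print(np.array_equal(head_shards[1], x[:, 2:4]))   # True: rank 1 owns heads 2..3

back = heads_to_seq(head_shards, sp)
print(all(np.array_equal(got, want) for got, want in zip(back, seq_shards)))  # True
```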
DistributedSampler
Distributed training uses DistributedSampler, which assigns samples to each GPU. The official PyTorch implementation is:
```python
import math
from typing import Iterator, Optional, TypeVar

import torch
import torch.distributed as dist
from torch.utils.data import Dataset, Sampler

T_co = TypeVar("T_co", covariant=True)


class DistributedSampler(Sampler[T_co]):
    r"""Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
    :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
    original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size and that any instance of it always
        returns the same elements in the same order.

    Args:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, :attr:`world_size` is retrieved from the
            current distributed group.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
            indices.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: ``False``.

    .. warning::
        In distributed mode, calling the :meth:`set_epoch` method at
        the beginning of each epoch **before** creating the :class:`DataLoader` iterator
        is necessary to make shuffling work properly across multiple epochs. Otherwise,
        the same ordering will be always used.

    Example::

        >>> # xdoctest: +SKIP
        >>> sampler = DistributedSampler(dataset) if is_distributed else None
        >>> loader = DataLoader(dataset, shuffle=(sampler is None),
        ...                     sampler=sampler)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     if is_distributed:
        ...         sampler.set_epoch(epoch)
        ...     train(loader)
    """

    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False) -> None:
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        if rank >= num_replicas or rank < 0:
            raise ValueError(
                f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.drop_last = drop_last
        # If the dataset length is evenly divisible by # of replicas, then there
        # is no need to drop any data, since the dataset will be split equally.
        if self.drop_last and len(self.dataset) % self.num_replicas != 0:  # type: ignore[arg-type]
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (len(self.dataset) - self.num_replicas) / self.num_replicas  # type: ignore[arg-type]
            )
        else:
            self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)  # type: ignore[arg-type]
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # deterministically shuffle based on epoch and seed
            g = torch.Generator()
            g.manual_seed(self.seed + self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()  # type: ignore[arg-type]
        else:
            indices = list(range(len(self.dataset)))  # type: ignore[arg-type]

        if not self.drop_last:
            # add extra samples to make it evenly divisible
            padding_size = self.total_size - len(indices)
            if padding_size <= len(indices):
                indices += indices[:padding_size]
            else:
                indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        else:
            # remove tail of data to make it evenly divisible.
            indices = indices[:self.total_size]
        assert len(indices) == self.total_size

        # subsample: each rank takes every num_replicas-th index, starting at its rank
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        r"""
        Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas
        use a different random ordering for each epoch. Otherwise, the next iteration of this
        sampler will yield the same ordering.

        Args:
            epoch (int): Epoch number.
        """
        self.epoch = epoch
```
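In `__init__`, there are two relevant arguments, num_replicas and rank:
- num_replicas: the total number of GPUs;
- rank: the index of the current GPU (its global rank).
For example, with 3 nodes of 8 A100s each, the third GPU on the second node has num_replicas=24 and rank=10.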
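The rank in this example follows from the common node-major layout, in which global rank = node index × GPUs per node + local rank, with both indices counted from 0. The helper below is illustrative and assumes that layout.

```python
def global_rank(node_index: int, local_rank: int, gpus_per_node: int) -> int:
    """Global rank under a node-major layout; node_index and local_rank start at 0."""
    return node_index * gpus_per_node + local_rank


num_nodes, gpus_per_node = 3, 8
num_replicas = num_nodes * gpus_per_node
# "Second node, third GPU" is node_index=1, local_rank=2 when counting from 0.
rank = global_rank(node_index=1, local_rank=2, gpus_per_node=gpus_per_node)
print(num_replicas, rank)  # 24 10
```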
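As the code shows, PyTorch's official DistributedSampler assigns different samples to every GPU. Sequence parallelism, however, requires all GPUs in the same sp_group to receive the same sample. That sample is then split into sp pieces, one per GPU, so DistributedSampler has to be reimplemented.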
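One way to do this is sketched below under two assumptions. First, ranks are grouped contiguously, so global ranks $[g \cdot sp, (g+1) \cdot sp)$ form sp_group $g$. Second, the index logic mirrors the PyTorch code above (shuffle by seed + epoch, pad to an even split) but distributes over the `world_size // sp` sp_groups instead of over all GPUs. `SequenceParallelSampler` is a hypothetical class, not xtuner's implementation.

```python
import math
import random
from typing import Iterator, List


class SequenceParallelSampler:
    """Hypothetical sampler: all ranks in the same sp_group receive the same indices.

    Assumes contiguous grouping: global ranks [g*sp, (g+1)*sp) form sp_group g.
    Mirrors torch's DistributedSampler (seeded shuffle, padding instead of
    drop_last) but splits the data over world_size // sp sp_groups.
    """

    def __init__(self, dataset_len: int, world_size: int, rank: int, sp: int,
                 shuffle: bool = True, seed: int = 0) -> None:
        if world_size % sp != 0:
            raise ValueError("world_size must be divisible by sp")
        self.dataset_len = dataset_len
        self.dp_size = world_size // sp  # number of sp_groups
        self.dp_rank = rank // sp        # index of this rank's sp_group
        self.shuffle = shuffle
        self.seed = seed
        self.epoch = 0
        self.num_samples = math.ceil(dataset_len / self.dp_size)
        self.total_size = self.num_samples * self.dp_size

    def __iter__(self) -> Iterator[int]:
        indices: List[int] = list(range(self.dataset_len))
        if self.shuffle:
            # Same seed + epoch on every rank -> the same permutation everywhere.
            random.Random(self.seed + self.epoch).shuffle(indices)
        # Pad by repeating indices so every sp_group gets num_samples items.
        padding_size = self.total_size - len(indices)
        if padding_size > 0:
            indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
        # Subsample by sp_group index, not by global rank.
        return iter(indices[self.dp_rank:self.total_size:self.dp_size])

    def __len__(self) -> int:
        return self.num_samples

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch


world_size, sp = 24, 4
samplers = [SequenceParallelSampler(100, world_size, r, sp, seed=42) for r in range(world_size)]
per_rank = [list(s) for s in samplers]
print(per_rank[0] == per_rank[3])  # True: ranks 0-3 form sp_group 0
print(per_rank[0] == per_rank[4])  # False: rank 4 belongs to sp_group 1
```

The stock DistributedSampler accepts num_replicas and rank explicitly. Passing `num_replicas=world_size // sp` and `rank=global_rank // sp` to it should therefore produce the same grouping without a new class.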
An example
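Suppose there are 3 nodes with 8 A100s each and a sequence parallel size of 4. This gives 24 / 4 = 6 sp_groups, and each sp_group processes one sample. In other words, one forward pass covers only 6 samples, since batch_size can only be set to 1 for ultra-long texts.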
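The sp_group layout in this example can be enumerated as follows, again assuming contiguous grouping of ranks. `sp_groups` is an illustrative helper.

```python
from typing import List


def sp_groups(world_size: int, sp: int) -> List[List[int]]:
    """Contiguous grouping (an assumption): sp_group g = ranks [g*sp, (g+1)*sp)."""
    return [list(range(g * sp, (g + 1) * sp)) for g in range(world_size // sp)]


groups = sp_groups(world_size=24, sp=4)
print(len(groups))  # 6 sp_groups -> 6 samples per forward pass with batch_size=1
print(groups[2])    # [8, 9, 10, 11]
```

In real training code, each of these rank lists would become a process group via `torch.distributed.new_group(ranks=...)`, which every process must call for every group.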
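However, when computing gradients, DDP divides by the total number of GPUs (24) even though only 6 distinct samples were processed. The resulting gradient is therefore smaller than the 24-sample average that plain data parallelism would produce. To fix this, set gradient_accumulation_step to the sequence parallel size (4 here), so that each optimizer step again covers 6 × 4 = 24 samples.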
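Counting distinct samples per optimizer step makes the reason for this setting concrete. The sketch assumes one sample per sp_group per micro-step (batch_size=1), and the function name is illustrative.

```python
def samples_per_optimizer_step(world_size: int, sp: int, micro_batch: int,
                               grad_accum: int) -> int:
    """Distinct samples per optimizer step: micro_batch samples per sp_group per micro-step."""
    return (world_size // sp) * micro_batch * grad_accum


print(samples_per_optimizer_step(24, sp=1, micro_batch=1, grad_accum=1))  # 24: plain data parallel
print(samples_per_optimizer_step(24, sp=4, micro_batch=1, grad_accum=1))  # 6
print(samples_per_optimizer_step(24, sp=4, micro_batch=1, grad_accum=4))  # 24 again
```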