手把手教你高效训练256K超长上下文窗口大模型(附代码)

教程来自元象XVERSE公众号

元象发布全球首个上下文窗口长度256K的开源大模型XVERSE-Long-256K,支持输入25万汉字,让大模型应用进入“长文本时代”。参数量和高质量数据量决定了大模型的计算复杂度,而长文本技术(Long Context)是大模型应用发展“杀手锏”,因技术新、研发难度高,目前多为闭源付费提供。

长序列大模型可一次性输入很长的序列,给大模型的使用带来革命性的变化。既然其优势明显,为何市面上仅有个位数的长文本模型呢?主因有:

  • 模型训练:GPU显存的占用与序列长度的平方成正比,使训练量急剧上升。
  • 模型结构:序列越长,模型的attention越分散,模型越容易忘记前序内容。
  • 推理速度:模型序列越长,将大幅度降低模型推理速度。

元象技术路线

长文本大模型技术是在近一年内发展出来的新技术,其主要技术方案为:

直接进行长序列的预训练,但会导致训练量成平方倍的提升。

通过位置编码的插值或外推拓展序列长度,这种方法会降低位置编码的分辨率,从而降低大模型输出效果。
元象总结出一条高效拓展长序列的技术路线,很好解决了训练量大、且位置编码分辨率降低等问题。以130亿参数的基座模型为例,具体如下:

  • 第一阶段:使用ABF+继续预训练的方法,将XVERSE-13B的序列长度从8K拓展到32K,该方法可以大幅度减少预训练的训练量。

  • 第二阶段:使用NTK+SFT的方法,将序列长度从32K拓展到256K。这里的继续预训练的方法可解决前文提到的训练量激增问题,而ABF和NTK可解决模型attention衰减问题。
    元象长文本大模型训练流程

手把手训练方案

第一阶段:ABF+继续预训练

继续预训练,顾名思义,是在原先短序列预训练的基础上进行长序列的预训练。具体地说,是在XVERSE-13B(8K的短序列)的基础上,使用20%的预训练数据进行32K的长序列的继续预训练。通过少量长序列数据的继续预训练而不是从头开始的长序列预训练,可以大幅减少预训练的训练量。

ABF的全称是Adjusted Base Frequency,是将位置编码RoPE(Rotary Position Embedding)的频率从10000修改成500000。

别小看这个数字的更改,它可以大幅减少前面序列attention的衰减速度,让后面的序列更好的获取所有序列的信息。此外,通过virtual pipeline、ZeRO、不间断训练等手段提高GPU的利用率。

经过以上技术优化,技术团队仅用了一个星期就完成了XVERSE-13B的继续预训练,将序列长度从8K拓展到32K。

第二阶段:NTK+SFT

仅使用继续预训练是无法将序列长度提升到256K的,如此长度的预训练无法完成,因此技术团队使用NTK和SFT技术进一步提高序列长度。

NTK的全称是Neural Tangent Kernel,翻译为神经正切核,是一种用于理解和分析深度神经网络行为的工具。使用了NTK的RoPE可以对RoPE的频率进行动态的插值。在于保持分辨率的情况下(高频),实现了频域空间缩放(低频),从而实现位置空间的插值。

废话不多说,上NTK的代码:

class XverseRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=500000, device=None):
        super().__init__()
        self.base = base
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        inv_freq = 1.0 / \
            (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings

        t = torch.arange(self.max_seq_len_cached,
                         device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[
                             None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[
                             None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:

            t = torch.arange(seq_len, device=x.device, dtype=torch.float32)
            dim = self.dim
            alpha = (seq_len / (self.max_position_embeddings/2) - 1)
            base = self.base * alpha ** (dim / (dim-2))
            ntk_inv_freq = 1.0 / \
                (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim))

            freqs = torch.einsum("i,j->ij", t, ntk_inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            cos_cached = emb.cos()[None, None, :, :]
            sin_cached = emb.sin()[None, None, :, :]
            return (
                cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
                sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype)
            )

        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

使用NTK,使模型具备了拓展更长文本的可能。接下来,通过SFT(Supervised Fine-Tuning)的方法,将序列长度进一步提升到256K。SFT是通过一问一答的方式,使得模型具备了Chat的能力并且强化NTK拓展序列长度的能力。

SFT的关键是如何生成训练数据。目前开源的长序列数据很少,更没有序列长度达到32K以上甚至256K的数据。团队使用预训练阶段使用的训练数据来构造长序列的SFT训练数据。

以构建多文档QA类数据为例:首先,基于自研的XVERSE-65B模型生成与单个文章有关的高质量问题回答对;然后将文章内容混合成目标长度的整段内容,随机选择与其中某个内容匹配的问题回答对;最后将该问题和回答作为整段内容的问题和回答,构成训练的单个样本。通过上述批量化数据生产管线,我们可以得到32K、64K,一直到256K长度的高质量对话数据。

总结一下,通过以上两个阶段的优化,将XVERSE-13B的序列长度大幅度拓展到了256K,并通过9项长文本模型评测和大海捞针等实验验证了XVERSE-13B-256K在长序列方面的强大性能。
在这里插入图片描述
在这里插入图片描述

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值