Deepmind发布新方法JEST:训练时间减少13倍,算力需求节省90%

Deepmind发布新方法JEST:训练时间减少13倍,算力需求节省90%

最近Google的人工智能团队发布了全新的数据训练方法——JEST,这种训练方法能够让训练时间减少13倍,让所消耗的算力降低90%,这无疑对AI领域是一个巨大的好消息,具体原因将在下文中具体展示。

传统的模型训练方法

首先来说一下传统的模型训练方法,一下是步骤:

一、数据准备

在训练大语言模型之前,首先需要准备训练数据。训练数据通常是大量的文本数据,这些数据可以从各种来源获取,例如新闻文章、社交媒体帖子、书籍等。数据的质量和多样性对模型的性能有很大影响,因此在选择和处理数据时需要谨慎。

1.1 数据选择

选择数据时,需要考虑数据的多样性和代表性,尽可能选择包含各种主题和风格的数据。此外,数据应该尽可能清洗和去噪,避免包含过多的错误和垃圾信息。

1.2 数据预处理

数据预处理是将原始数据转化为模型可以接受的格式的过程。这通常包括分词、去除停用词、词干提取等步骤。预处理的目的是减少模型需要处理的数据复杂性,使模型能够更好地学习文本的语义。

二、模型选择

模型选择是训练大语言模型的第二个步骤。目前,最常用的大语言模型包括Transformer、BERT、GPT等。这些模型各有优缺点,选择哪种模型取决于你的具体需求和资源。

2.1 Transformer

Transformer是一种基于自注意力机制的模型,它在处理长距离依赖问题上表现出色。然而,由于其全连接的自注意力机制,Transformer的计算复杂度较高。

2.2 BERT

BERT是基于Transformer的一个预训练模型,它通过预测句子中的缺失词来学习语言的语义。BERT在许多NLP任务上都取得了很好的效果,但其训练过程需要大量的计算资源。

2.3 GPT

GPT是另一个基于Transformer的预训练模型,它使用自回归方式学习语言模型。GPT在生成任务上表现优秀,但其只能从左到右进行预测,无法利用右侧的上下文信息。

三、训练过程

训练大语言模型的过程通常包括前向传播、损失计算、反向传播和参数更新四个步骤。这个过程需要在大量数据上反复进行,直到模型的性能达到满意的程度。

3.1 前向传播

前向传播是将输入数据送入模型,通过模型的各层计算得到预测结果的过程。

3.2 损失计算

损失计算是根据模型的预测结果和真实标签计算损失的过程。常用的损失函数包括交叉熵损失、均方误差损失等。

3.3 反向传播

反向传播是根据损失函数的梯度更新模型参数的过程。这是训练模型的关键步骤,它决定了模型学习的速度和效果。

3.4 参数更新

参数更新是将计算得到的梯度应用到模型的参数上,以改进模型的性能。

四、模型优化

模型优化是训练大语言模型的最后一个步骤,它包括模型微调、正则化、学习率调整等方法。

4.1 模型微调

模型微调是在预训练模型的基础上,对模型进行细致的调整,以适应特定任务。

4.2 正则化

正则化是一种防止模型过拟合的技术,它通过在损失函数中添加一个惩罚项来限制模型的复杂度。

4.3 学习率调整

学习率调整是一种改变模型学习速度的方法,它可以帮助模型在训练初期快速收敛,在训练后期避免过度拟合。

传统模型训练的缺点

首先是耗电量巨大、算力要求高,就拿Meta Llama3最大参数的70B模型举例,Meta用了接近100兆瓦的电力,和两个接近2.4万张V100显卡,而且Meta还计划在今年(2024)年底增加60万张H100算力基础设施。目前Llama 3的总碳排放量约为2290吨。对于目前环保的大趋势来说,肯定是非常不好的(找不出合适的词了)。

其次是训练时间长,OpenAI用了13万亿个token训练出了GPT-4,用了25000个A100训练了90到100天,而且利用率在32%到36%之间,故障数量过多也是极低利用率的原因,这会导致需要重新从之前的检查点开始训练。仅训练成本就估计有6300万美元。这还不包括所有的实验、失败的训练和其他成本,比如数据收集、RLHF、人力成本等。

全新的Deepmind的JEST训练方法

JEST是最近Google的人工智能实验室DeepMind推出的全新的模型训练方法,目的是减少模型训练的算力需求和训练时间。

下面就根据PDF来实际说说吧!https://arxiv.org/pdf/2406.17711

首先是技术细节,JEST运用以下技术:

  1. 联合样本选择算法(JEST)

    • 目标:从一个大“超级批次”(super-batch)中选择一个子批次(sub-batch),使其对学习最有用。

    • 评分机制:

      • 学习者难度(Hard Learner):选择对当前模型(学习者)损失较高的批次。公式为:
        在这里插入图片描述

        其中ℓ(B∣θ)表示批次B在模型参数θ下的损失。

      • 易参考模型(Easy Reference):选择对预训练参考模型损失较低的批次。公式为:
        在这里插入图片描述

        其中ℓ(Bθ∗)表示批次B在参考模型参数θ*下的损失。

      • 可学习性(Learnability):结合上述两者,选择对学习者损失高但对参考模型损失低的批次。公式为:
        在这里插入图片描述

    • 算法流程:

      • 初始从超级批次中随机选择一个子批次。

      • 计算当前子批次中每个样本的条件可学习性。

      • 迭代地从剩余样本中选择新的样本,直到达到预定的子批次大小。

      • 具体算法见PDF中的Algorithm 1:

        def jointly_sample_batch(learner_loss, ref_loss, n_chunks=16, filter_ratio=0.8, method="learnability"):
            scores = learner_loss - ref_loss if method == "learnability" else - ref_loss
            n_images = scores.shape[0]  # scores.shape = [B, B]
            n_draws = int(n_images * (1 - filter_ratio) / n_chunks)  # Size of each chunk.
            logits_ii = np.diag(scores)  # Self-similarity scores.
            inds = random.choice(logits_ii, n_draws)  # Sample first chunk.
            for _ in range(n_chunks - 1):
                is_sampled = np.eye(n_images)[inds].sum(axis=0)  # Binary indicator of current samples [n_images,].
                logits_ij = (scores * is_sampled.reshape(n_images, 1)).sum(axis=0)  # Negative terms ij [n_images,].
                logits_ji = (scores * is_sampled.reshape(1, n_images)).sum(axis=1)  # Negative terms ji [n_images,].
                logits = logits_ii + logits_ij + logits_ji  # Conditional learnability given past samples.
                logits = logits - is_sampled * 1e8  # Avoid sampling with replacement.
                new_inds = random.choice(n_images, n_draws, p=np.exp(logits))
                inds = np.concatenate((inds, new_inds))  # Expand the array of indices sampled.
            return inds  # Gather and return subset indices.
        
  2. 多分辨率训练

    • 方法:将子批次分成两部分,分别以不同分辨率进行编码,低分辨率部分用于加速训练。

    • 具体实现:

      • 将子批次B随机分成两部分:Blo和Bhi。

      • 低分辨率部分(Blo)使用较大的patch(如32x32)进行编码,高分辨率部分(Bhi)使用较小的patch(如16x16)进行编码。

      • 低分辨率编码结果:在这里插入图片描述

      • 高分辨率编码结果:在这里插入图片描述

      • 将两部分编码结果拼接:在这里插入图片描述

    • 算法实现:

      def loss_fn(params, params_ref, batch):
          images, texts = batch
          approx = True if cfg.method == "flexi-jest" else False
          # Score and sub-sample the initial super-batch
          embeds = model.forward(images, texts, params, approx=approx)  # [5B, D]
          embeds_ref = batch["embeds_ref"]  # Pre-cached in dataset
          if cfg.loss_type == "sigmoid":
              scores = get_scores_sigmoid(embeds, embeds_ref, params, params_ref)  # Get scores
              inds = jointly_sample_batch(scores, cfg.n_chunks, cfg.filter_ratio, cfg.learnability)
          elif cfg.loss_type == "softmax":
              inds = jointly_sample_batch_softmax(embeds_ref, embeds, params_ref, params, cfg.n_chunks, cfg.filter_ratio)  # for softmax loss, scores are re-computed in the iterative sampling.
          images, texts = stop_grad(images[inds]), stop_grad(texts[inds])  # [B, ...]
          # Split batch for co-training
          images_full, images_approx = images[::2], images[1::2]  # [B/2, ...], [B/2, ...]
          texts_full, texts_approx = texts[::2], texts[1::2]  # [B/2, ...], [B/2, ...]
          # Compute overall loss
          embeds_full = model.forward(images_full, texts_full, params, approx=False)  # [B/2, D], [B/2, D]
          embeds_approx = model.forward(images_approx, texts_approx, params, approx=approx)  # [B/2, D], [B/2, D]
          zimg = np.concatenate([embeds_full[0], embeds_approx[0]], axis=0)
          ztxt = np.concatenate([embeds_full[1], embeds_approx[1]], axis=0)
          if loss_type == "sigmoid":
              loss, _ = sigmoid_nll(params, (zimg, ztxt))
          elif loss_type == "softmax":
              loss, _, _ = softmax_nll(params, (zimg, ztxt), is_sampled=None)
          return loss
      
  3. 高效评分和多分辨率训练

    • 在线模型近似:使用低分辨率图像编码来减少计算开销。具体实现中,采用FlexiViT架构,通过降低图像分辨率来减少计算量。
    • 多分辨率训练:在训练过程中同时使用高分辨率和低分辨率图像编码,结合两者的优点。具体实现见上文算法。

利用这一点(前面的评分机制),团队采用一种基于阻塞吉布斯采样的迭代方法,逐步构建批次,每次迭代中根据条件可学习性评分选择新的样本子集。

img

与单独选择数据相比,新方法在过滤更多数据时持续改进。包括使用仅基于预训练的参考模型来评分数据也是如此,即CLIPScore,这是离线基础数据集筛选的流行基线。

img

不过,过滤更多数据会增加浮点运算次数(FLOPs),因为评分需要学习者和参考模型进行推理传递。

对此,团队在数据集中缓存了预训练的参考模型分数,他们采用了FlexiViT架构进行低分辨率评分,并在多种分辨率下进行了训练。

img

img

总而言之,相关变体JEST++和FlexiJEST++的性能显著优于许多其他先前的SOTA模型,同时使用的计算量更少。

参考:

https://www.thepaper.cn/newsDetail_forward_28022338

https://arxiv.org/pdf/2406.17711

https://new.qq.com/rain/a/20230711A05Z3Q00

https://www.huxiu.com/article/2939607.html

https://www.datalearner.com/llm-blogs/guide_to_training_large_language_models

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CodeMicheal

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值