Deepmind发布新方法JEST:训练时间减少13倍,算力需求节省90%
最近Google的人工智能团队发布了全新的数据训练方法——JEST,这种训练方法能够让训练时间减少13倍,让所消耗的算力降低90%,这无疑对AI领域是一个巨大的好消息,具体原因将在下文中具体展示。
传统的模型训练方法
首先来说一下传统的模型训练方法,一下是步骤:
一、数据准备
在训练大语言模型之前,首先需要准备训练数据。训练数据通常是大量的文本数据,这些数据可以从各种来源获取,例如新闻文章、社交媒体帖子、书籍等。数据的质量和多样性对模型的性能有很大影响,因此在选择和处理数据时需要谨慎。
1.1 数据选择
选择数据时,需要考虑数据的多样性和代表性,尽可能选择包含各种主题和风格的数据。此外,数据应该尽可能清洗和去噪,避免包含过多的错误和垃圾信息。
1.2 数据预处理
数据预处理是将原始数据转化为模型可以接受的格式的过程。这通常包括分词、去除停用词、词干提取等步骤。预处理的目的是减少模型需要处理的数据复杂性,使模型能够更好地学习文本的语义。
二、模型选择
模型选择是训练大语言模型的第二个步骤。目前,最常用的大语言模型包括Transformer、BERT、GPT等。这些模型各有优缺点,选择哪种模型取决于你的具体需求和资源。
2.1 Transformer
Transformer是一种基于自注意力机制的模型,它在处理长距离依赖问题上表现出色。然而,由于其全连接的自注意力机制,Transformer的计算复杂度较高。
2.2 BERT
BERT是基于Transformer的一个预训练模型,它通过预测句子中的缺失词来学习语言的语义。BERT在许多NLP任务上都取得了很好的效果,但其训练过程需要大量的计算资源。
2.3 GPT
GPT是另一个基于Transformer的预训练模型,它使用自回归方式学习语言模型。GPT在生成任务上表现优秀,但其只能从左到右进行预测,无法利用右侧的上下文信息。
三、训练过程
训练大语言模型的过程通常包括前向传播、损失计算、反向传播和参数更新四个步骤。这个过程需要在大量数据上反复进行,直到模型的性能达到满意的程度。
3.1 前向传播
前向传播是将输入数据送入模型,通过模型的各层计算得到预测结果的过程。
3.2 损失计算
损失计算是根据模型的预测结果和真实标签计算损失的过程。常用的损失函数包括交叉熵损失、均方误差损失等。
3.3 反向传播
反向传播是根据损失函数的梯度更新模型参数的过程。这是训练模型的关键步骤,它决定了模型学习的速度和效果。
3.4 参数更新
参数更新是将计算得到的梯度应用到模型的参数上,以改进模型的性能。
四、模型优化
模型优化是训练大语言模型的最后一个步骤,它包括模型微调、正则化、学习率调整等方法。
4.1 模型微调
模型微调是在预训练模型的基础上,对模型进行细致的调整,以适应特定任务。
4.2 正则化
正则化是一种防止模型过拟合的技术,它通过在损失函数中添加一个惩罚项来限制模型的复杂度。
4.3 学习率调整
学习率调整是一种改变模型学习速度的方法,它可以帮助模型在训练初期快速收敛,在训练后期避免过度拟合。
传统模型训练的缺点
首先是耗电量巨大、算力要求高,就拿Meta Llama3最大参数的70B模型举例,Meta用了接近100兆瓦的电力,和两个接近2.4万张V100显卡,而且Meta还计划在今年(2024)年底增加60万张H100算力基础设施。目前Llama 3的总碳排放量约为2290吨。对于目前环保的大趋势来说,肯定是非常不好的(找不出合适的词了)。
其次是训练时间长,OpenAI用了13万亿个token训练出了GPT-4,用了25000个A100训练了90到100天,而且利用率在32%到36%之间,故障数量过多也是极低利用率的原因,这会导致需要重新从之前的检查点开始训练。仅训练成本就估计有6300万美元。这还不包括所有的实验、失败的训练和其他成本,比如数据收集、RLHF、人力成本等。
全新的Deepmind的JEST训练方法
JEST是最近Google的人工智能实验室DeepMind推出的全新的模型训练方法,目的是减少模型训练的算力需求和训练时间。
下面就根据PDF来实际说说吧!https://arxiv.org/pdf/2406.17711
首先是技术细节,JEST运用以下技术:
-
联合样本选择算法(JEST):
-
目标:从一个大“超级批次”(super-batch)中选择一个子批次(sub-batch),使其对学习最有用。
-
评分机制:
-
学习者难度(Hard Learner):选择对当前模型(学习者)损失较高的批次。公式为:
其中ℓ(B∣θ)表示批次B在模型参数θ下的损失。
-
易参考模型(Easy Reference):选择对预训练参考模型损失较低的批次。公式为:
其中ℓ(B∣θ∗)表示批次B在参考模型参数θ*下的损失。
-
可学习性(Learnability):结合上述两者,选择对学习者损失高但对参考模型损失低的批次。公式为:
-
-
算法流程:
-
初始从超级批次中随机选择一个子批次。
-
计算当前子批次中每个样本的条件可学习性。
-
迭代地从剩余样本中选择新的样本,直到达到预定的子批次大小。
-
具体算法见PDF中的Algorithm 1:
def jointly_sample_batch(learner_loss, ref_loss, n_chunks=16, filter_ratio=0.8, method="learnability"): scores = learner_loss - ref_loss if method == "learnability" else - ref_loss n_images = scores.shape[0] # scores.shape = [B, B] n_draws = int(n_images * (1 - filter_ratio) / n_chunks) # Size of each chunk. logits_ii = np.diag(scores) # Self-similarity scores. inds = random.choice(logits_ii, n_draws) # Sample first chunk. for _ in range(n_chunks - 1): is_sampled = np.eye(n_images)[inds].sum(axis=0) # Binary indicator of current samples [n_images,]. logits_ij = (scores * is_sampled.reshape(n_images, 1)).sum(axis=0) # Negative terms ij [n_images,]. logits_ji = (scores * is_sampled.reshape(1, n_images)).sum(axis=1) # Negative terms ji [n_images,]. logits = logits_ii + logits_ij + logits_ji # Conditional learnability given past samples. logits = logits - is_sampled * 1e8 # Avoid sampling with replacement. new_inds = random.choice(n_images, n_draws, p=np.exp(logits)) inds = np.concatenate((inds, new_inds)) # Expand the array of indices sampled. return inds # Gather and return subset indices.
-
-
-
多分辨率训练:
-
方法:将子批次分成两部分,分别以不同分辨率进行编码,低分辨率部分用于加速训练。
-
具体实现:
-
将子批次B随机分成两部分:Blo和Bhi。
-
低分辨率部分(Blo)使用较大的patch(如32x32)进行编码,高分辨率部分(Bhi)使用较小的patch(如16x16)进行编码。
-
低分辨率编码结果:
-
高分辨率编码结果:
-
将两部分编码结果拼接:
-
-
算法实现:
def loss_fn(params, params_ref, batch): images, texts = batch approx = True if cfg.method == "flexi-jest" else False # Score and sub-sample the initial super-batch embeds = model.forward(images, texts, params, approx=approx) # [5B, D] embeds_ref = batch["embeds_ref"] # Pre-cached in dataset if cfg.loss_type == "sigmoid": scores = get_scores_sigmoid(embeds, embeds_ref, params, params_ref) # Get scores inds = jointly_sample_batch(scores, cfg.n_chunks, cfg.filter_ratio, cfg.learnability) elif cfg.loss_type == "softmax": inds = jointly_sample_batch_softmax(embeds_ref, embeds, params_ref, params, cfg.n_chunks, cfg.filter_ratio) # for softmax loss, scores are re-computed in the iterative sampling. images, texts = stop_grad(images[inds]), stop_grad(texts[inds]) # [B, ...] # Split batch for co-training images_full, images_approx = images[::2], images[1::2] # [B/2, ...], [B/2, ...] texts_full, texts_approx = texts[::2], texts[1::2] # [B/2, ...], [B/2, ...] # Compute overall loss embeds_full = model.forward(images_full, texts_full, params, approx=False) # [B/2, D], [B/2, D] embeds_approx = model.forward(images_approx, texts_approx, params, approx=approx) # [B/2, D], [B/2, D] zimg = np.concatenate([embeds_full[0], embeds_approx[0]], axis=0) ztxt = np.concatenate([embeds_full[1], embeds_approx[1]], axis=0) if loss_type == "sigmoid": loss, _ = sigmoid_nll(params, (zimg, ztxt)) elif loss_type == "softmax": loss, _, _ = softmax_nll(params, (zimg, ztxt), is_sampled=None) return loss
-
-
高效评分和多分辨率训练:
- 在线模型近似:使用低分辨率图像编码来减少计算开销。具体实现中,采用FlexiViT架构,通过降低图像分辨率来减少计算量。
- 多分辨率训练:在训练过程中同时使用高分辨率和低分辨率图像编码,结合两者的优点。具体实现见上文算法。
利用这一点(前面的评分机制),团队采用一种基于阻塞吉布斯采样的迭代方法,逐步构建批次,每次迭代中根据条件可学习性评分选择新的样本子集。
与单独选择数据相比,新方法在过滤更多数据时持续改进。包括使用仅基于预训练的参考模型来评分数据也是如此,即CLIPScore,这是离线基础数据集筛选的流行基线。
不过,过滤更多数据会增加浮点运算次数(FLOPs),因为评分需要学习者和参考模型进行推理传递。
对此,团队在数据集中缓存了预训练的参考模型分数,他们采用了FlexiViT架构进行低分辨率评分,并在多种分辨率下进行了训练。
总而言之,相关变体JEST++和FlexiJEST++的性能显著优于许多其他先前的SOTA模型,同时使用的计算量更少。
参考:
https://www.thepaper.cn/newsDetail_forward_28022338
https://arxiv.org/pdf/2406.17711
https://new.qq.com/rain/a/20230711A05Z3Q00
https://www.huxiu.com/article/2939607.html
https://www.datalearner.com/llm-blogs/guide_to_training_large_language_models