使用 DeepSpeed 和 Hugging Face Transformer 微调 FLAN-T5 XL/XXL

Scaling Instruction-Finetuned Language Models 论文发布了 FLAN-T5 模型,它是 T5 模型的增强版。FLAN-T5 由很多各种各样的任务微调而得,因此,简单来讲,它就是个方方面面都更优的 T5 模型。相同参数量的条件下,FLAN-T5 的性能相比 T5 而言有两位数的提高。Google 在 Hugging Face 上开源了 5 个 FLAN-T5 的 checkpoints,参数量范围从 8000 万  到 110 亿。

  • Scaling Instruction-Finetuned Language Models 论文地址:
    https://arxiv.org/pdf/2210.11416.pdf

  • 关于 FLAN-T5 的模型筛选结果:
    https://hf.co/models?other=arxiv:2210.11416

在之前的一篇博文中,我们已经学习了如何 针对聊天对话数据摘要生成任务微调 FLAN-T5,那时我们使用的是 Base (250M 参数) 模型。本文,我们将研究如何将训练从 Base 扩展到 XL (30 亿参数) 或 XXL (110 亿参数)。

  • 针对聊天对话数据摘要生成任务微调 FLAN-T5 指南:
    https://www.philschmid.de/fine-tune-flan-t5

  • Base (250M 参数) 模型:
    https://hf.co/google/flan-t5-base

  • XL (30 亿参数) 模型:
    https://hf.co/google/flan-t5-xl

  • XXL (110 亿参数) 模型:
    https://hf.co/google/flan-t5-xxl

这意味着我们将学习如何利用模型并行、多 GPU 以及 DeepSpeed ZeRO 来微调 FLAN-T5 XL 和 XXL。

  • DeepSpeed ZeRO 链接:
    https://www.deepspeed.ai/tutorials/zero/

除了作为教程的部分之外,我们还跑了一系列实验,这些实验数据可以帮助你选择正确的硬件设置。你可以在 结果和实验 部分找到详细信息。

# install git lfs for pushing artifacts
!sudo apt install git-lfs
# install torch with the correct cuda version, check nvcc --version
!pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 --upgrade
# install Hugging Face Libraries
!pip install "transformers==4.26.0" "datasets==2.9.0" "accelerate==0.16.0" "evaluate==0.4.0" --upgrade
# install deepspeed and ninja for jit compilations of kernels
!pip install "deepspeed==0.8.0" ninja --upgrade
# install additional dependencies needed for training
!pip install rouge-score nltk py7zr tensorboard

处理数据集

与 针对聊天对话的摘要生成任务微调 FLAN-T5 一文中类似,我们需要先准备一个用于微调的数据集。本文,我们将在 CNN Dailymail 数据集 上微调 FLAN-T5-XXL。我们不会赘述如何生成数据集,如果你想了解数据集生成的详细步骤,请参阅前文提到的 Fine Tune FLAN-T5。

  • CNN Dailymail 数据集:
    https://hf.co/datasets/cnn_dailymail

  • FLAN-T5-XXL:
    https://hf.co/google/flan-t5-xxl

我们定义了一些参数,本文的示例都会基于这些参数,但你可以根据实际需要进行调整。

# 实验配置
model_id = "google/flan-t5-xxl" # Hugging Face 模型 Id
dataset_id = "cnn_dailymail" # Hugging Face 数据集 Id
dataset_config = "3.0.0" # 数据集版本
save_dataset_path = "data" # 存放处理后数据的本地路径
text_column = "article" # 输入文本所属列
summary_column = "highlights" # 输出文本所属列
# 定制指令提示格式
prompt_template = f"Summarize the following news article:\n{
   {input}}\nSummary:\n"

与 Fine Tune FLAN-T5 不同,这次我们把预处理和训练分开。这样我们就可以在非 GPU

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值