一文速览Gemma及其微调(第5.2版):通过我司七月的早期paper-7方面review数据集微调Gemma2

前言

如此文《七月论文审稿GPT第3.2版和第3.5版:通过paper-review数据集分别微调Mistral、gemma》所讲

Google作为曾经的AI老大,我司自然紧密关注,所以当Google总算开源了一个gemma 7b,作为有技术追求、技术信仰的我司,那必须得支持一下,比如用我司的paper-review数据集微调试下,彰显一下gemma的价值与威力

后来我司考虑到毕竟llama的生态更完善、迭代速度更快,故之后更多是微调llama,然后Google到底是不甘落后,24年6.27,在时隔4个月之后,Google终于推出了gemma的升级版:gemma2

我其实想说,如果是前几年的AI时代,这个速度可以了,但如今是大模型时代,还是太慢了(毕竟llama已到3,Claude则已到3.5)

  1. 不过既然推出了,加之我司把论文审稿的数据弄成7方面review之后,llama2、llama3都还没pk赢过gpt4(4方面review下微调llama2,早已赢过了GPT4-1106;7方面review下微调llama3却只和GPT4打成平手)
  2. 如此,可以让情况4的早7数据 微调下gemma2,过程中保持:“微调的prompt用的是potential,与阿荀给的训练数据中的摘要prompt的格式一致”,包括推理的prompt

预期是:开源模型得在7方面review下微调后的表现,类似4方面review那样,也是可以超过gpt4的(功夫不负有心人的是,终于超了,详见本文文末,也预示着我司审稿团队开发整整一年的七月论文审稿GPT 达到了对外发布上线的标准了,太不容易了..)

如此,便有了本文(且把之前关于gemma1的介绍也从上面那篇文章 《通过paper-review数据集分别微调Mistral、gemma》中脱离出来,归纳到本文)

第一部分 Google推出gemma,试图与llama、Mistral形成三足鼎立之势

Google在聊天机器人这个赛道上,可谓被双向夹击

  • 闭源上被OpenAI的ChatGPT持续打压一年多(尽管OpenAI用的很多技术比如transformer、CoT都是Google发明的,尽管Google推出了强大的Gemini)
  • 开源上则前有Meta的llama,后有Mistral的来势汹汹

终于在24年2.21,按耐不住推出了开源模型gemma(有2B、7B两个版本,这是其技术报告这是其解读之一),试图对抗与llama、Mistral在开源场景上形成三足鼎立之势

1.1 gemma 7B的性能:比肩Mistral 7B、超越llama 7B

Gemma 7B在 18 个基于文本的任务中的 11 个上优于相似参数规模的开放模型,例如除了问答上稍逊于llama 13B,其他诸如常识推理、数学和科学、编码等任务上的表现均超过了llama2 7B/13B(关于llama2的介绍请看此文的第三部分)、Mistral 7B

1.2 模型架构:基于Transformer解码器、多头/多查询注意力、RoPE、GeGLU、RMSNorm

1.2.1 基于Transformer解码器:上下文8192、词表256K、训练数据集6万亿token

Gemma 模型架构基于 Transformer 解码器

  1. 模型训练的上下文长度为 8192 个 token
  2. 其词表则比llama2 所用的32K大太多了,为 256k个token(导致我们微调gemma 7b时,在论文审稿所需要的理想长度12K之下且在已经用了qlora和flash attention的前提之下,48g显存都不够,详见此文)
  3. 至于训练数据集达到6万亿个token(即We trained Gemma models on up to 6T tokens of text,而llama2的训练集只有2万亿个token)

1.2.2 7B/2B模型使用多头注意力/多查询注意力、RoPE、GeGLU、RMSNorm

此外,gemma还在原始 transformer 论文的基础上进行了改进,改进的部分包括:

  • 多查询注意力:7B 模型使用多头注意力(即MHA,如下图最左侧所示),而 2B 检查点使用多查询注意力「即MQA,如下图最右侧所示,𝑛𝑢𝑚_𝑘𝑣_ℎ𝑒𝑎𝑑𝑠 = 1,更多介绍请参见《一文通透各种注意力:从多头注意力MHA到分组查询注意力GQA、多查询注意力MQA

  • RoPE 嵌入:Gemma 在每一层中使用旋转位置嵌入,而不是使用绝对位置嵌入;此外,Gemma 还在输入和输出之间共享嵌入,以减少模型大小
  • GeGLU 激活:GeGLU 激活函数(其对应的论文为Google发的这篇《GLU Variants Improve Transformer),取代传统的 ReLU 非线性函数

    GeGLU是GeLU(Gaussian Error Linear Unit) 的门线性单元变体,而GeLU与ReLU不允许负值不同,其允许为负输入值执行梯度传播

    ​总之,GeGLU 的激活被分割为两部分,分别是 sigmoid 单元和线性映射单元(它与 sigmoid 单元的输出逐元素相乘),使得其与Llama 2和Mistral等用的SwiGLU极其类似(关于SwiGLU的细致介绍请看此文《LLaMA的解读与其微调:Alpaca-LoRA/Vicuna/BELLE/中文LLaMA/姜子牙/LLaMA 2》的1.2.3节:SwiGLU替代ReLU)

    唯一的区别是 GeGLU 使用的基础激活函数是 GeLU 而不是 Swish

  • 归一化:Gemma 对每个 transformer 子层的输入和输出进行归一化,这与仅对其中一个或另一个进行归一化的标准做法有所不同,另,gemma使用RMSNorm 作为归一化层

    如国外一开发者Sebastian Raschka所说,“At first glance, it sounds like Gemma has an additional RMSNorm layer after each transformer block. However, looking at the official code implementation, it turns out that Gemma just uses the regular pre-normalization scheme that is used by other LLMs like GPT-2, Llama 2(Gemma 仅仅使用了 GPT-2、Llama 2 等其他 LLM 使用的常规预归一化方案), and so on, as illustrated below

1.2.3 预训练、指令调优、RLHF、监督微调

对于 7B 模型,谷歌在 16 个pod(共计4096 个TPUv5e)上训练模型,他们通过 2 个pod对2B模型进行预训练,总计 512 TPUv5e

在一个 pod 中,谷歌对 7B 模型使用 16 路模型分片和 16 路数据复制,对于 2B 模型,只需使用 256 路数据复制

优化器状态使用类似 ZeRO-3 的技术进一步分片。在 pod 之外,谷歌使用了 Pathways 方法通过数据中心网络执行数据复制还原

  • 预训练
    Gemma 2B 和 7B 分别在来自网络文档、数学和代码的 2T 和 6T 主要英语数据上进行训练。与 Gemini 不同的是,这些模型不是多模态的,也不是为了在多语言任务中获得最先进的性能而训练的
    为了兼容,谷歌使用了 Gemini 的 SentencePiece tokenizer 子集。它可以分割数字,不删除多余的空白,并对未知 token 进行字节级编码
  • 指令调优与RLHF
    谷歌通过在仅文本、仅英语合成和人类生成的 prompt 响应对的混合数据上进行监督微调即SFT,以及利用在仅英语标记的偏好数据和基于一系列高质量 prompt 的策略上训练的奖励模型进行人类反馈强化学习即RLHF,对 Gemma 2B 和 Gemma 7B 模型进行微调

    具体而言
    \rightarrow​  gemma根据基于 LM 的并行评估结果来选择自己的混合数据,以进行监督微调。给定一组留出的(heldout) prompt, 让测试模型生成response,并让基线模型生成相同prompt下的response,然后让规模更大的高性能模型来预测哪个response更符合人类的偏好
    \rightarrow​  gemma还构建不同的 prompt 集来突出特定的能力,例如指令遵循、真实性、创造性和安全性等。gemma使用了不同的自动化LM裁判,它们采用了多种技术,比如思维链提示、对齐人类偏好等

第二部分 gemma2

24年6.27,Google终于推出了gemma的升级版:gemma2,至于在参数规模上,先推出了9B和27B的,2B的则待发布(这是其技术报告,这是其技术blog)

2.1 同是8192长度/RoPE/GeGLU/RMSNorm,但用了滑窗注意力/GQA/Logit软上限

总的来说,gemma2与gemma1的相同之处是的长下文长度也是8192个token,也使用旋转位置编码、GeGLU非线性函数,以及使用RMSNorm来归一化每个transformer子层、注意力层和前馈层的输入和输出

不同之处则在于

  1. gemma2在每一层交替使用局部滑动窗口注意力与全局注意力「We alternate between a local sliding window attention (Beltagy et al., 2020a,b) and global attention (Luong et al., 2015) in every other laye
    其中,局部注意力层的滑动窗口大小设置的4096,而全局注意力层的跨度则设置的8192
  2. 9B和27B模型都使用分组查询注意力GQA(在其技术报告中说的是,之所以选择 GQA,是因为它需要更少的参数,并且在推理时更快),其中num_groups = 2(如下图中部所示)

    为方便大家一目了然,我用一个表再总结归纳如下
    多头注意力MHA分组查询注意力GQA(逐渐成主流)多查询注意力MQA
    LLaMA2(仅34B 70B)、Llama 3(8B 70B)ChatGLM2
    Mistral 7BGoogle Gemini
    Google gemma1 7BGoogle gemma2(包含9B 27B)Google gemma1 2B
  3. Logit软上限:Logit soft-capping
    遵循Gemini 1.5(Gemini Team, 2024),gemma2在每个注意力层和最后一层中对logit进行软上限限制,使得logit的值保持在-soft_cap和+soft_cap之间
    更具体地说,我们设置logit为:

    对于9B和27B模型,他们将注意力logit限制在50.0,最终logit限制在30.0
    不过,考虑到gemma2的这个特性「注意力logit软上限」与常见的FlashAttention实现不兼容,故他们在针对“使用了FlashAttention的HuggingFace transformers库和vLLM实现库”时移除了这个特性,即废除了Logit软上限这个特性,说的不客气点,就是个鸡肋

总之,以下是具体的模型参数

参数/设计选择2.6B9B27B
模型尺寸d_model230435844608
层数Layers264246
前归一化pre-norm
后归一化post-norm
非线性激活函数
Non-linearity
GeGLUGeGLUGeGLU
前馈维数
Feedforward dim
184322867273728
头部类型Head typeGQAGQAGQA
头部数量Num heads 81632
KV头部数量
Num KV heads
4816
头部大小256256128
全局注意力跨度global att. span819281928192
滑动窗口sliding window409640964096
词汇表大小Vocab size256128256128256128
Tied embedding

2.2 预训练:训练数据、分词器、过滤器、知识蒸馏

2.2.1  gemma2的训练数据:9B 8万亿

gemma2的各个不同参数规模模型所用的训练数据自然不同,具体如下所示(过程中,进一步使用类似于ZeRO3的技术进行分片)

  • Gemma2 27B:在主要为英语数据的13万亿个tokens上训练
    在8x24x32配置的TPUv5p上训练,总计6144个芯片,具有768路数据复制和8路模型分片
  • Gemma2 9B模型:在8万亿个tokens上训练
    在8x16x32配置的TPUv4上训练,总计4096个芯片,具有1024路数据复制和4路模型分片
  • Gemma2 2.6B模型:在2万亿个tokens上训练
    在2x16x16配置的TPUv5e上训练,总计512个芯片,具有512路数据复制和1路模型分片

这些tokens来自多种数据源,包括网页文档、代码和科学文章。 我们的模型不是多模态的,也不是专门为最先进的多语言能力而训练的,而最终的数据混合是通过类似于Gemini 1.0的消融方法确定的

  • 分词器
    使用与Gemma 1和Gemini相同的分词器:一个带有拆分数字、保留空白和字节级编码的SentencePiece分词器,生成的词汇表有256k个条目
  • 过滤
    使用与Gemma 1相同的数据过滤技术。 具体来说,我们过滤预训练数据集以减少不需要或不安全的言论的风险,过滤掉某些个人信息或其他敏感数据

2.2.2 知识蒸馏:教师模型带学生模型

给定一个用作教师的大模型,通过蒸馏教师对每个token x在其上下文x_{c}中给出的概率来训练较小的模型,即P_{T}\left(x \mid x_{c}\right)

更准确地说,最小化教师模型和学生模型在各自计算出来的概率之间的负对数似然:

\min _{P_{S}} \sum_{x}-P_{T}\left(x \mid x_{c}\right) \log P_{S}\left(x \mid x_{c}\right)

其中P_{S}是学生模型的参数化概率。 在实践中,我们对教师模型进行一次推理并存储概率。 由于词汇表有265k个条目,故只存储教师概率的采样子集

2.3 微调:先SFT后RLHF

为了得到可以遵循指令的微调模型

  1. 首先,我们在一组仅文本、仅英语的合成和人工生成的prompt-response对上应用监督微调(SFT)
  2. 然后,我们在这些模型的基础上应用RLHF,奖励模型在标注的仅限英语的偏好数据上训练,策略基于与SFT阶段相同的prompt
  3. 最后,我们对每个阶段获得的模型进行平均,以提高它们的整体性能
    最终的数据混合和后训练方案,包括调整后的超参数,都是基于在提高有用性的同时最小化模型危害的基础上选择

2.3.1 通过教师模型做SFT训练

在合成和真实prompt上运行行为克隆,response主要由教师(即更大的模型)合成生成,且还在学生模型的分布上运行教师的蒸馏

2.3.1 RLHF:从人类反馈中进行强化学习

使用与Gemma 1类似的RLHF算法,但使用了不同的奖励模型,该模型比策略大一个数量级。新的奖励模型也更倾向于对话能力,特别是多轮对话 

且We average models from experiments run with different hyperparameters (Ramé et al., 2024

最终效果上还可以的,27B模型在一些测试基准上所表现出来的能力也能接近llama3 70B

第三部分 七月论文审稿GPT第5.2版:通过早期paper-7方面review数据集微调gemma2

3.1 gemma2微调环境配置

3.1.1 事先准备

  • Linux系统
  • 支持cuda12.1
  • 单张/多张A100 80G显卡
  • 可访问HuggingFace/Python官方源的网络代理

3.1.2 模型下载

# 安装 git-lfs
curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash
sudo apt-get install git-lfs
git lfs install

# 下载模型
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/google/gemma-2-9b-it
cd Gemma-2-9b-it
git lfs pull --include="*.safetensors"
cd ..

3.1.3 环境安装

# 安装虚拟环境
conda create -n llm python=3.10            --- 按y
# 激活虚拟环境
source activate llm
或者
conda activate llm

# pip 安装依赖
pip install -r requirement_new.txt

# 单独安装flash attention
# 安装flash attention 前最好保证linux "外界"的 cuda版本 与 conda 虚拟环境中cuda版本一致(不一致可能导致使用过程中报"不匹配"的错误)
pip install packaging
pip install flash-attn==2.5.7 --no-build-isolation          --- flashatten编译过程需要一定的时间,需要等待

# 如果硬盘空间小可以选择删除pip包的缓冲(删除后安装需重新下)
pip cache purge

此外,transformers库需要安装较新版本,使用下面的命令安装:

pip install git+https://github.com/huggingface/transformers
# 本次微调使用的版本为4.42.3

3.2 gemma2微调过程

上文提到,我们微调gemma2可能需要准备一张或多张显存为80G的A100的显卡,原因为:

  • 如上文2.1节所述,截止2024年7月5日,gemma2中的logit soft caping机制在sdpa和flash attention v2并不支持
    本次使用eager attention 进行微调的,而eager attention没有进行内存和计算优化
  • huggingface gemma2的SWA机制在eager attention中通过attention_mask的模拟实现的
    即本质来说还是标准的self-attention结构,其仅仅实现了SWA的结果,并没有实现SWA减少计算复杂度的目的,在微调过程占用显存较高
  • gemma2 相比较 llama3 来说,虽说QKV的数量减少,但模型的词汇表和模型的层数多了很多,微调的时候可能需要保存更多的中间激活值,显存的开销较大

项目组同事青睐在微调gemma2的过程中,使用了SWA交替层attention和logit soft caping机制,但推理侧分别评估了是否开启logit soft caping机制时推理的性能差距,详见下文的3.3节

3.2.1 微调参数改动(主要参数)

参数

说明

batch size=16

梯度累计总batch size=16

lr=1e-4

学习率的大小

max_prompt_length=11138

paper 最长的大小,超过将被截取

max_response_length=1150

review 最长的大小,超过将被截取

save_steps=100

迭代100次保存一次模型

num_train_epoch=3

迭代3个epoch

此外

参数

说明

max_seq_length=10865

paper + review 最大的长度

max_prompt_length=9715

paper 最长的大小,超过将被截取

max_response_length=1150

review 最长的大小,超过将被截取

3.2.2 代码细节改动:情况4下的微调system prompt与dataset改动

3.2.2.1 情况4 微调system prompt

gemma2微调的system prompt与llama3情况4的prompt基本一致,仅最后一行将“.”改为“:”

但需要注意的是gemma2默认的指令模版中没有system prompt这个选项,故system prompt与“user” 拼接在一起使用,具体见下文

SYSTEM_PROMPT = """Below is an "Instruction" that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
Instruction:
You are a professional machine learning conference reviewer who reviews a given paper and considers 7 criteria: 
** How to evaluate the idea of the paper **
** Compared to previous similar works, what are the essential differences, such as any fundamental differences, improvements, innovations **
** How to evaluate the experimental results in the paper **
** Potential reasons for acceptance **
** Potential reasons for rejection **
** Other suggestions for further improving the quality of the paper **
** Other important review comments **
The given paper is as follows:"""
3.2.2.2 dataset 改动

这部分的代码详见七月官网首页的:大模型项目开发线上营2期

3.2.2.3 gemma2 PI 长度扩展实现

继承 Gemma2RotaryEmbedding,新增

  1. Gemma2LinearScalingRotaryEmbedding类支持PI
  2. Gemma2DynamicNTKScalingRotar yEmbedding类支持动态NTK
  3. 在Gemma2Attention中使用self._init_rope进行初始化

这3个类或函数是青睐仿照llama3改的,至于完整代码详见七月官网首页的:大模型项目开发线上营2期

3.2.2.4 Huggingface generate 长文本推理bug修复

截止2024年7月5日,Huggingface generate 推理 gemma2时,如果prompt token + 推理后的新token 超过了一个窗口的大小(4096)时,会报一个CUDA error的错误

为了解决长文本推理的问题,把gemma2 sdpa attention的forward函数中past_key_value.update 函数修改为下面22行-24行的代码即可

def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        
        < --------------省略代码-------------- >
      
        if past_key_value is not None:
            # sin and cos are specific to RoPE models; cache_position needed for the static cache
            cache_kwargs = {
                "sin": sin,
                "cos": cos,
                "sliding_window": self.sliding_window,
                "cache_position": cache_position,
            }
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs, 
                                                             # 新增参数sliding_window 以支持 _sliding_update 函数
                                                             sliding_window = self.sliding_window)
                                                             
        < --------------省略代码-------------- >

除了上述长文本的推理,截止2024年7月5日,Huggingface gemma2模型疑似与google gemma2 SWA和全局attention的交替顺序不一致:即 google gemma2 layer index为0,2,4,6...为SWAattention,1,3,5,7...为全局attention,而Huggingface 实现与其正好相反

不过,后来7月中旬左右,Huggingface transformers已经修复了这两个bug

3.2.3 模型迭代过程

  1. 显存占用上

  2. loss迭代上

    过程中,由于微调参数与llama3相差无几,故均选择与llama3迭代次数相同的1800为最终的checkpoint

3.3 推理结果

  • 本章节gemma2模型推理时,空项序列抑制惩罚系数均为0.95,均使用上文中微调后的gemma2-9b-it,仅仅是推理策略不同(是否加入logit soft caping、是否空项序列抑制)
  • gemma2 generate默认推理是未加logit soft caping,若使用该机制转为eager attention推理

3.3.1 未加入logit soft caping + 空项序列抑制推理结果

3.1.1.1 命中数pk评估

首先,PK下llama3

  • 下图左图:情况4微调gemma2-9b-it + PI 后仅加入序列抑制推理
  • 下图右图:情况4微调llama3-8b-instruct-8k +PI 后仅加入序列抑制推理

结论:同加入相同的空项抑制系数为0.95,gamma2的性能在论文审稿场景略强于llama3

其次,PK下GPT4

  • 下图左图:情况4微调gemma2-9b-it + PI 后仅加入序列抑制推理
  • 下图右图:基于7大项prompt工程提示的gpt4-1106

结论:加入空项序列抑制策略后,gemma2论文审稿的性能首次超越了基于7大项prompt提示的gpt4-1106,真不容易啊..

// 更多见七月官网首页的:大模型项目开发线上营2期

  • 26
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

v_JULY_v

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值