从Mistral Nemo到Large2 核心技术详解

从Mistral Nemo到Large2 核心技术详解

作者:Kevin吴嘉文,新加坡管理大学 信息技术硕士
原文:https://zhuanlan.zhihu.com/p/711294388

在本文中,梳理了 Mistral 系列模型(Mistral 7B, Mixtral 8x7B,Mixtral 8x22B,Mistral Nemo, Mistral Large 2)的关键信息,包括它们的主要特点、亮点以及相关资源链接。

Mistral 7B

官方博客:https://mistral.ai/news/announcing-mistral-7b/
mistral 7B 论文:https://arxiv.org/abs/2310.06825

Mistral 7B模型的亮点包括:

Sliding Window Attention

Mistral 采用的 window size 为 4096,而后一共有 32 层layer,那么采用 SWA 之后,理论上在进行 attention 的时候,理论上可以收集到约 131K tokens 的信息。(虽然论文里提到的 window size 是 4096,但 官方提供的 huggingface 上的权重[1] 中 max_position_embeddings 为 32768,且在新一点的版本中,比如 mistral-7b-instruct-v0.2[2] ,都不采用 sliding window 了)

图片

由于代用了固定的 attention 窗口大小,因此我们只需要一个大小为 W=window size 的 cache ,在计算第 i 个 token 的 cache 的时候,只需要覆盖 cache 中 i mod M 位置上的 hidden state 即可。

参考 huggingface 的 mistral 实现,Sliding window attention 通过 attention_mask 来控制:

# huggignface mistral attn mask 实现
def _update_causal_mask(
self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values:Cache,
):
# ... 省略部分无关代码
    past_seen_tokens = cache_position[0]if past_key_values isnotNoneelse0
    using_static_cache = isinstance(past_key_values,StaticCache)
    using_sliding_window_cache = isinstance(past_key_values,SlidingWindowCache)

    dtype, device = input_tensor.dtype, input_tensor.device
    min_dtype = torch.finfo(dtype).min
    sequence_length = input_tensor.shape[1]
# SlidingWindowCache
if using_sliding_window_cache:
        target_length = max(sequence_length,self.config.sliding_window)
# StaticCache
elif using_static_cache:
        target_length = past_key_values.get_max_length()
# DynamicCache or no cache
else:
        target_length =(
            attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length +1
)

if attention_mask isnotNoneand attention_mask.dim()==4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max()!=0:
raiseValueError('Custom 4D attention mask should be passed in inverted form with max==0`')
        causal_mask = attention_mask
else:
        causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
        exclude_mask = torch.arange(target_length, device=device)> cache_position.reshape(-1,1)
ifself.config.sliding_window isnotNone:
ifnot using_sliding_window_cache or sequence_length >self.config.sliding_window:
                exclude_mask.bitwise_or_(
                    torch.arange(target_length, device=device)
<=(cache_position.reshape(-1,1)-self.config.sliding_window)
)
        causal_mask *= exclude_mask
        causal_mask = causal_mask[None,None,:,:].expand(input_tensor.shape[0],1,-1,-1)
if attention_mask isnotNone:
            causal_mask = causal_mask.clone()# copy to contiguous memory for in-place edit
if attention_mask.dim()==2:
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[:,:,:,:mask_length]+ attention_mask[:,None,None,:]
                padding_mask = padding_mask ==0
                causal_mask[:,:,:,:mask_length]= causal_mask[:,:,:,:mask_length].masked_fill(
                    padding_mask, min_dtype
)

return causal_mask

GQA (Grouped Query Attention)

Paper:GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
Abs:https://arxiv.org/abs/2305.13245

图片

grouped-query attention 指出,Multi-Query Attention[3] 提高了推理速度的同时,却可能极大地降低回复质量。因此根据上图,GQA 在推理速度和质量之间作了权衡。

以下为 GQA 文中的实验结果,值得注意的是论文中使用原 MHA checkpoint 转换为 GQA 权重后,还进行了额外的预训练:

图片

此外 Mistral,Llama2 的部分模型使用 GQA 时,采用的 kv head 数量似乎都是 8。

为什么现在大家都在用 MQA 和 GQA?[4] 文中提到 MQA 和 GQA 能获得巨大加速的一个点在于:GPU 内存强的限制。由于 MQA 和 GQA 都降低了内存中数据的读取量,减少了计算单元的等待时间,因此推理速度的提高比想象中的要快更多。

Mixtral 8*7B

论文:https://arxiv.org/abs/2401.04088
huggingface 模型权重:https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
官方博客:https://mistral.ai/news/mixtral-of-experts/
huggingface 模型代码:https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
混合专家模型基础(推荐):https://huggingface.co/blog/zh/moe

官方给出的评分来看,mixtral 8*7 和 GPT3.5 有的一比。

  • • 发布时间:23年12月

  • • 模型大小:8 个 expert MLP 层,一共45B 大小。

  • • 训练:除了预训练外,Mixtral MOE 后续还开源了一个经过 SFT + DPO 微调的版本。

  • • 模型效果:

图片
  • • 架构:Mixtral 的 MOE 架构类似于,在 MoE 模型中,只有 FFN 层被视为独立的专家,而模型的其他参数是共享的。大致参数为:

图片

参考 huggingface 中的 mixtral 和 mistral 实现对比,差异在于 mixtral 中将传统 transformer decoder layer 中的 FFN 替换为了 block_sparse_moe。

图片

主要逻辑

G(x)=Softmax(TopK(x⋅Wgate))final hidden states=∑i=0n−1G(x)i⋅Ei(x)

其中 Ei(x) 为专家对应的网络,具体展示为下面 huggingface 实现中的 MixtralBlockSparseTop2MLP。mixtral 中采用了 8 个 expert,每次推理使用选取 top 2 的 expert 进行推理。比如输入一句话 你好,今天,那么我们每个 token 都会选出 top 2 的 expert 来负责这个 token 的预测,因此在推理 你好,今天 时,有概率所有 expert 都会参与到计算当中,具体可以参考 MixtralSparseMoeBlock 的实现。

图片

mixtral 论文中提到专家分配在不同主题(如ArXiv论文、生物学和哲学文档)中没有明显的模式,只有在DM数学中显示出边际上的差异,这可能是由于其数据集的合成性质和有限的自然语言覆盖范围所致。router 在某些句法结构上表现出一定的结构化行为(比如 python 的 self 等),同时连续标记通常被分配给相同的专家。

huggingface 中的 mixtral 核心代码

class MixtralDecoderLayer(nn.Module):
def __init__(self, config:MixtralConfig, layer_idx:int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

self.block_sparse_moe =MixtralSparseMoeBlock(config)
self.input_layernorm =MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm =MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
        hidden_states: torch.Tensor,
        attention_mask:Optional[torch.Tensor]=None,
# 此处省略参数 ..
)->Tuple[torch.FloatTensor,Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:

        residual = hidden_states
        hidden_states =self.input_layernorm(hidden_states)
        hidden_states, self_attn_weights, present_key_value =self.self_attn(
# 此处省略参数 
)
        hidden_states = residual + hidden_states
        residual = hidden_states
        hidden_states =self.post_attention_layernorm(hidden_states)

# Mixtral 将原本的 hidden_states = self.FFN(hidden_states) 替换为了:
        hidden_states, router_logits =self.block_sparse_moe(hidden_states)

        hidden_states = residual + hidden_states
        outputs =(hidden_states,)

return outputs

huggingface 中 block_sparse_moe 的实现(省略部分次要代码):

class MixtralSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok

self.gate = nn.Linear(self.hidden_dim,self.num_experts, bias=False)
self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config)for _ in range(self.num_experts)])

self.jitter_noise = config.router_jitter_noise

def forward(self, hidden_states: torch.Tensor)-> torch.Tensor:
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        router_logits =self.gate(hidden_states)# (batch * sequence_length, n_experts)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights,self.top_k, dim=-1)
        routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)

# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be sollicitated
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2,1,0)

# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
            expert_layer =self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)

            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
# current_state: shape (n_i, hidden_dim)
# 所有 current_state 的长度 n 总和为 batch * sequence_length
            current_hidden_states = expert_layer(current_state)* routing_weights[top_x, idx,None]

# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits

其中:MixtralBlockSparseTop2MLP 长这样:

class MixtralBlockSparseTop2MLP(nn.Module):
def __init__(self, config:MixtralConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size

self.w1 = nn.Linear(self.hidden_dim,self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim,self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim,self.ffn_dim, bias=False)

self.act_fn = ACT2FN[config.hidden_act]

def forward(self, hidden_states):
        current_hidden_states =self.act_fn(self.w1(hidden_states))*self.w3(hidden_states)
        current_hidden_states =self.w2(current_hidden_states)
return current_hidden_states

推理部分的话,根据模型参数量 45B 来推理的话,如果用 fp16 的话推理的话,得需要至少 90GB 以上的显存,如果用 4 bit的话,30GB 显存就够了。量化的生成速度,可以参考这个 redis[5] 中的评论,大致为 :

推理精度设备速度 tokens/s
Q4_K_M单卡 4090 + 7950X3D20
Q4_K_M2 x 309048.26

如果有 100+GB 以上显存,可以用 vllm 快速搭建测试 api:

docker run --gpus all \
    -e HF_TOKEN=$HF_TOKEN -p 8000:8000 \
    ghcr.io/mistralai/mistral-src/vllm:latest \
    --host 0.0.0.0 \
    --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
    --tensor-parallel-size 2 # 100+GB 显存 \
    --load-format pt # needed since both `pt` and `safetensors` are available

Nvidia TensorRT-LLM[6] 博客中,记录了 Mixtral 8*7B 的吞吐量测试(input and output sequence lengths of 128):

图片

input and output sequence lengths of 128

文中没有给出当 sequence lengths 最大时候的吞吐量,但根据上图数据,可以猜测 2个 H100 部署 8*7B 正常服务用户时,平均吞吐量应该可以大于 7500Tokens/秒,根据 H100 的功耗计算电费成本的话,生成 1M token 需要耗约为 0.02 度电。

Mixtral 8*22B

官方博客:https://mistral.ai/news/mixtral-8x22b/
huggingface 开源模型:https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1
  • • 架构:架构与 mixtral 8*7B 架构一样,在 huggingface 中使用的都是MixtralForCausalLM ,但 22B 的各方面参数大一点,比较特别的是 context window 从 32k 升级到了 65k, vocab_size 也更大一些。

  • • 支持 function calling,不过好像没有透露具体的 function calling 训练细节。

  • • 数学和 coding 能力明显超越 llama2 70B。

  • • 似乎对中文的支持不是很好。

图片

Mistral 团队开源的模型,都比较注重 coding 和 math 的能力,Mixtral 系列的模型在这方便表现也是比较好:

图片

Mistral Nemo

官方博客:https://mistral.ai/news/mistral-nemo/
huggingface 模型权重:https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407

Mistral Nemo 使用的也是 MistralForCausalLM 架构,与 mistral 7B 的差别为:Mistral Nemo 的 hidden_size 从 4096 变为 5120;max_position_embeddings 变为 1024000,num_hidden_layers 增加到 40, vocab_size 增加到 131072,不用 sliding window。

此外,Mistral Nemo 支持 function calling,采用了 Tekken 作为 tokenizer,比 SentencePiece 更高效(压缩率更高,官方描述是~30% more efficient at compressing,不确定是哪个方面的 efficient)

NVIDIA 在这个博客[7]中提到:Mistral Nemo 采用这样的设计,是为了能够适配单个NVIDIA L40S、NVIDIA GeForce RTX 4090或NVIDIA RTX 4500 GPU。模型采用 Megatron-LM[8] 训练,用了 3,072 个 H100 80GB 。

但光采用 FP16 加载整个 Mistral Nemo 就需要花 23 GB 显存,要是要跑满整个 context window size,除了量化外,还是得需要采用 offload 或者其他方法来推理

不过 mistral 官方把 12 B 的模型和其他 8B 的模型对比,感觉好像不太公平:

图片

Mistral Large 2

官方博客:https://mistral.ai/news/mistral-large-2407/
huggingface 模型权重:https://huggingface.co/mistralai/Mistral-Large-Instruct-2407

Mistral Large 2,参数量 123B,主打多语言以及 coding 能力。采用与 mistral 7B 一样的架构,huggingface 中同样使用 MistralForCausalLM;比较值得注意的是 context window size 为 131072,不用 sliding window。同样支持 function call。

Llama 3.1 刚出不久,就拿 Mistral Large 2 和别人来对比:

图片

在代码能力上,Mistral large 2 比 llama 3.1 平均效果更好。

图片

除了 coding 和数学外,在MT Bench 的评分也比 llama 3.1 高,平均生成的回复长度比 llama 3.1 要短

图片

同时,中文能力相对上一代 mistral large 有大步幅提升:

图片
引用链接

[1] huggingface 上的权重: https://huggingface.co/mistralai/mathstral-7B-v0.1/blob/main/config.json
[2] mistral-7b-instruct-v0.2: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/blob/main/config.json
[3] Multi-Query Attention: https://arxiv.org/pdf/1911.02150.pdf
[4] 为什么现在大家都在用 MQA 和 GQA?: https://zhuanlan.zhihu.com/p/647130255
[5] 这个 redis: https://www.reddit.com/r/LocalLLaMA/comments/18jslmf/tokens_per_second_mistral_8x7b_performance/?rdt=57036
[6] TensorRT-LLM: https://developer.nvidia.com/blog/achieving-high-mixtral-8x7b-performance-with-nvidia-h100-tensor-core-gpus-and-tensorrt-llm/?ncid=so-twit-928467/
[7] 这个博客: https://blogs.nvidia.com/blog/mistral-nvidia-ai-model/
[8] Megatron-LM: https://github.com/NVIDIA/Megatron-LM

包包算法笔记

包大人的大模型、深度学习、机器学习笔记。

152篇原创内容

公众号

相关文章

LLama 405B技术报告解读(二)

LLama 405B 技术报告解读

大模型Infra发展路径盘点

月之暗面kimi底层推理系统方案揭秘(二)

月之暗面kimi底层推理系统方案揭秘

谷歌Gemma-2大模型开源|技术报告解读

英伟达超大号340B大模型技术报告

大模型训练十戒

PPO vs DPO 对齐擂台的武林纷争

大模型的微调数据选择技巧(三)

大模型的微调数据选择技巧(二)

大模型微调数据选择和构造技巧

如何从零训练多模态大模型(预训练方向)

大模型预训练中的数据处理及思考

如何从零开始训练大模型(预训练方向)

从头预训练一只超迷你 LLaMA 3

大模型落地实用主义思考

一文逮尽知名开源大模型作弊!训题库...

大模型测试集作弊?ICLR论文将leak一网打尽!

大模型reward model的trick

大模型比赛kaggle Prompt Recovery方案解读

大模型微调经验和认知

大模型训练loss突刺原因和解决办法

大模型用8个7B超越70B的方法

大模型/AIGC/Agent必读百篇文章获取

大模型在任务型对话上有机会吗

国内AI大模型已近80个,哪个最有前途?

大模型如何修复badcase

大模型中的Scaling Law计算方法

垂直领域大模型落地思考

大模型的生产优化

大模型Kaggle比赛首秀冠军方案总结

大模型RLHF理论详细讲解

24家国内大模型面经

大模型面试八股含答案

大模型无标注对齐RLAIF讲解

大模型训练为什么用A100不用4090

大模型百川2技术报告细节分享

大模型来自面试的一些体会和分享

判断场景是否适合大模型

大模型微调技术报告汇总

大模型的幻觉问题

从零训练大模型教程

领域/场景大模型也太难训了吧

大模型开源社区的原子弹Llama2

大模型训练的一些坑点和判断

大模型RLHF的trick

大模型评测,也太难了吧

大模型面试八股

大模型微调样本构造的trick

大模型训练太难了!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI生成曾小健

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值