从Mistral Nemo到Large2 核心技术详解
作者:Kevin吴嘉文,新加坡管理大学 信息技术硕士
原文:https://zhuanlan.zhihu.com/p/711294388
在本文中,梳理了 Mistral 系列模型(Mistral 7B, Mixtral 8x7B,Mixtral 8x22B,Mistral Nemo, Mistral Large 2)的关键信息,包括它们的主要特点、亮点以及相关资源链接。
Mistral 7B
官方博客:https://mistral.ai/news/announcing-mistral-7b/
mistral 7B 论文:https://arxiv.org/abs/2310.06825
Mistral 7B模型的亮点包括:
Sliding Window Attention
Mistral 采用的 window size 为 4096,而后一共有 32 层layer,那么采用 SWA 之后,理论上在进行 attention 的时候,理论上可以收集到约 131K tokens 的信息。(虽然论文里提到的 window size 是 4096,但 官方提供的 huggingface 上的权重[1] 中 max_position_embeddings 为 32768,且在新一点的版本中,比如 mistral-7b-instruct-v0.2[2] ,都不采用 sliding window 了)
由于代用了固定的 attention 窗口大小,因此我们只需要一个大小为 W=window size 的 cache ,在计算第 i 个 token 的 cache 的时候,只需要覆盖 cache 中 i mod M 位置上的 hidden state 即可。
参考 huggingface 的 mistral 实现,Sliding window attention 通过 attention_mask 来控制:
# huggignface mistral attn mask 实现
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values:Cache,
):
# ... 省略部分无关代码
past_seen_tokens = cache_position[0]if past_key_values isnotNoneelse0
using_static_cache = isinstance(past_key_values,StaticCache)
using_sliding_window_cache = isinstance(past_key_values,SlidingWindowCache)
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
# SlidingWindowCache
if using_sliding_window_cache:
target_length = max(sequence_length,self.config.sliding_window)
# StaticCache
elif using_static_cache:
target_length = past_key_values.get_max_length()
# DynamicCache or no cache
else:
target_length =(
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length +1
)
if attention_mask isnotNoneand attention_mask.dim()==4:
# in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
if attention_mask.max()!=0:
raiseValueError('Custom 4D attention mask should be passed in inverted form with max==0`')
causal_mask = attention_mask
else:
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
exclude_mask = torch.arange(target_length, device=device)> cache_position.reshape(-1,1)
ifself.config.sliding_window isnotNone:
ifnot using_sliding_window_cache or sequence_length >self.config.sliding_window:
exclude_mask.bitwise_or_(
torch.arange(target_length, device=device)
<=(cache_position.reshape(-1,1)-self.config.sliding_window)
)
causal_mask *= exclude_mask
causal_mask = causal_mask[None,None,:,:].expand(input_tensor.shape[0],1,-1,-1)
if attention_mask isnotNone:
causal_mask = causal_mask.clone()# copy to contiguous memory for in-place edit
if attention_mask.dim()==2:
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:,:,:,:mask_length]+ attention_mask[:,None,None,:]
padding_mask = padding_mask ==0
causal_mask[:,:,:,:mask_length]= causal_mask[:,:,:,:mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
GQA (Grouped Query Attention)
Paper:GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints
Abs:https://arxiv.org/abs/2305.13245
grouped-query attention 指出,Multi-Query Attention[3] 提高了推理速度的同时,却可能极大地降低回复质量。因此根据上图,GQA 在推理速度和质量之间作了权衡。
以下为 GQA 文中的实验结果,值得注意的是论文中使用原 MHA checkpoint 转换为 GQA 权重后,还进行了额外的预训练:
此外 Mistral,Llama2 的部分模型使用 GQA 时,采用的 kv head 数量似乎都是 8。
为什么现在大家都在用 MQA 和 GQA?[4] 文中提到 MQA 和 GQA 能获得巨大加速的一个点在于:GPU 内存强的限制。由于 MQA 和 GQA 都降低了内存中数据的读取量,减少了计算单元的等待时间,因此推理速度的提高比想象中的要快更多。
Mixtral 8*7B
论文:https://arxiv.org/abs/2401.04088
huggingface 模型权重:https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1
官方博客:https://mistral.ai/news/mixtral-of-experts/
huggingface 模型代码:https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py
混合专家模型基础(推荐):https://huggingface.co/blog/zh/moe
官方给出的评分来看,mixtral 8*7 和 GPT3.5 有的一比。
-
• 发布时间:23年12月
-
• 模型大小:8 个 expert MLP 层,一共45B 大小。
-
• 训练:除了预训练外,Mixtral MOE 后续还开源了一个经过 SFT + DPO 微调的版本。
-
• 模型效果:
-
• 架构:Mixtral 的 MOE 架构类似于,在 MoE 模型中,只有 FFN 层被视为独立的专家,而模型的其他参数是共享的。大致参数为:
参考 huggingface 中的 mixtral 和 mistral 实现对比,差异在于 mixtral 中将传统 transformer decoder layer 中的 FFN 替换为了 block_sparse_moe。
主要逻辑
G(x)=Softmax(TopK(x⋅Wgate))final hidden states=∑i=0n−1G(x)i⋅Ei(x)
其中 Ei(x) 为专家对应的网络,具体展示为下面 huggingface 实现中的 MixtralBlockSparseTop2MLP。mixtral 中采用了 8 个 expert,每次推理使用选取 top 2 的 expert 进行推理。比如输入一句话 你好,今天,那么我们每个 token 都会选出 top 2 的 expert 来负责这个 token 的预测,因此在推理 你好,今天 时,有概率所有 expert 都会参与到计算当中,具体可以参考 MixtralSparseMoeBlock 的实现。
mixtral 论文中提到专家分配在不同主题(如ArXiv论文、生物学和哲学文档)中没有明显的模式,只有在DM数学中显示出边际上的差异,这可能是由于其数据集的合成性质和有限的自然语言覆盖范围所致。router 在某些句法结构上表现出一定的结构化行为(比如 python 的 self 等),同时连续标记通常被分配给相同的专家。
huggingface 中的 mixtral 核心代码
class MixtralDecoderLayer(nn.Module):
def __init__(self, config:MixtralConfig, layer_idx:int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
self.block_sparse_moe =MixtralSparseMoeBlock(config)
self.input_layernorm =MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm =MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
hidden_states: torch.Tensor,
attention_mask:Optional[torch.Tensor]=None,
# 此处省略参数 ..
)->Tuple[torch.FloatTensor,Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
residual = hidden_states
hidden_states =self.input_layernorm(hidden_states)
hidden_states, self_attn_weights, present_key_value =self.self_attn(
# 此处省略参数
)
hidden_states = residual + hidden_states
residual = hidden_states
hidden_states =self.post_attention_layernorm(hidden_states)
# Mixtral 将原本的 hidden_states = self.FFN(hidden_states) 替换为了:
hidden_states, router_logits =self.block_sparse_moe(hidden_states)
hidden_states = residual + hidden_states
outputs =(hidden_states,)
return outputs
huggingface 中 block_sparse_moe 的实现(省略部分次要代码):
class MixtralSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok
self.gate = nn.Linear(self.hidden_dim,self.num_experts, bias=False)
self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config)for _ in range(self.num_experts)])
self.jitter_noise = config.router_jitter_noise
def forward(self, hidden_states: torch.Tensor)-> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
router_logits =self.gate(hidden_states)# (batch * sequence_length, n_experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts = torch.topk(routing_weights,self.top_k, dim=-1)
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
# we cast back to the input dtype
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)
# One hot encode the selected experts to create an expert mask
# this will be used to easily index which expert is going to be sollicitated
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2,1,0)
# Loop over all available experts in the model and perform the computation on each expert
for expert_idx in range(self.num_experts):
expert_layer =self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
# Index the correct hidden states and compute the expert hidden state for
# the current expert. We need to make sure to multiply the output hidden
# states by `routing_weights` on the corresponding tokens (top-1 and top-2)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
# current_state: shape (n_i, hidden_dim)
# 所有 current_state 的长度 n 总和为 batch * sequence_length
current_hidden_states = expert_layer(current_state)* routing_weights[top_x, idx,None]
# However `index_add_` only support torch tensors for indexing so we'll use
# the `top_x` tensor here.
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
return final_hidden_states, router_logits
其中:MixtralBlockSparseTop2MLP 长这样:
class MixtralBlockSparseTop2MLP(nn.Module):
def __init__(self, config:MixtralConfig):
super().__init__()
self.ffn_dim = config.intermediate_size
self.hidden_dim = config.hidden_size
self.w1 = nn.Linear(self.hidden_dim,self.ffn_dim, bias=False)
self.w2 = nn.Linear(self.ffn_dim,self.hidden_dim, bias=False)
self.w3 = nn.Linear(self.hidden_dim,self.ffn_dim, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, hidden_states):
current_hidden_states =self.act_fn(self.w1(hidden_states))*self.w3(hidden_states)
current_hidden_states =self.w2(current_hidden_states)
return current_hidden_states
推理部分的话,根据模型参数量 45B 来推理的话,如果用 fp16 的话推理的话,得需要至少 90GB 以上的显存,如果用 4 bit的话,30GB 显存就够了。量化的生成速度,可以参考这个 redis[5] 中的评论,大致为 :
推理精度 | 设备 | 速度 tokens/s |
Q4_K_M | 单卡 4090 + 7950X3D | 20 |
Q4_K_M | 2 x 3090 | 48.26 |
如果有 100+GB 以上显存,可以用 vllm 快速搭建测试 api:
docker run --gpus all \
-e HF_TOKEN=$HF_TOKEN -p 8000:8000 \
ghcr.io/mistralai/mistral-src/vllm:latest \
--host 0.0.0.0 \
--model mistralai/Mixtral-8x7B-Instruct-v0.1 \
--tensor-parallel-size 2 # 100+GB 显存 \
--load-format pt # needed since both `pt` and `safetensors` are available
Nvidia TensorRT-LLM[6] 博客中,记录了 Mixtral 8*7B 的吞吐量测试(input and output sequence lengths of 128):
input and output sequence lengths of 128
文中没有给出当 sequence lengths 最大时候的吞吐量,但根据上图数据,可以猜测 2个 H100 部署 8*7B 正常服务用户时,平均吞吐量应该可以大于 7500Tokens/秒,根据 H100 的功耗计算电费成本的话,生成 1M token 需要耗约为 0.02 度电。
Mixtral 8*22B
官方博客:https://mistral.ai/news/mixtral-8x22b/
huggingface 开源模型:https://huggingface.co/mistralai/Mixtral-8x22B-Instruct-v0.1
-
• 架构:架构与 mixtral 8*7B 架构一样,在 huggingface 中使用的都是MixtralForCausalLM ,但 22B 的各方面参数大一点,比较特别的是 context window 从 32k 升级到了 65k, vocab_size 也更大一些。
-
• 支持 function calling,不过好像没有透露具体的 function calling 训练细节。
-
• 数学和 coding 能力明显超越 llama2 70B。
-
• 似乎对中文的支持不是很好。
Mistral 团队开源的模型,都比较注重 coding 和 math 的能力,Mixtral 系列的模型在这方便表现也是比较好:
Mistral Nemo
官方博客:https://mistral.ai/news/mistral-nemo/
huggingface 模型权重:https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407
Mistral Nemo 使用的也是 MistralForCausalLM 架构,与 mistral 7B 的差别为:Mistral Nemo 的 hidden_size 从 4096 变为 5120;max_position_embeddings 变为 1024000,num_hidden_layers 增加到 40, vocab_size 增加到 131072,不用 sliding window。
此外,Mistral Nemo 支持 function calling,采用了 Tekken 作为 tokenizer,比 SentencePiece 更高效(压缩率更高,官方描述是~30% more efficient at compressing,不确定是哪个方面的 efficient)
NVIDIA 在这个博客[7]中提到:Mistral Nemo 采用这样的设计,是为了能够适配单个NVIDIA L40S、NVIDIA GeForce RTX 4090或NVIDIA RTX 4500 GPU。模型采用 Megatron-LM[8] 训练,用了 3,072 个 H100 80GB 。
但光采用 FP16 加载整个 Mistral Nemo 就需要花 23 GB 显存,要是要跑满整个 context window size,除了量化外,还是得需要采用 offload 或者其他方法来推理
不过 mistral 官方把 12 B 的模型和其他 8B 的模型对比,感觉好像不太公平:
Mistral Large 2
官方博客:https://mistral.ai/news/mistral-large-2407/
huggingface 模型权重:https://huggingface.co/mistralai/Mistral-Large-Instruct-2407
Mistral Large 2,参数量 123B,主打多语言以及 coding 能力。采用与 mistral 7B 一样的架构,huggingface 中同样使用 MistralForCausalLM;比较值得注意的是 context window size 为 131072,不用 sliding window。同样支持 function call。
Llama 3.1 刚出不久,就拿 Mistral Large 2 和别人来对比:
在代码能力上,Mistral large 2 比 llama 3.1 平均效果更好。
除了 coding 和数学外,在MT Bench 的评分也比 llama 3.1 高,平均生成的回复长度比 llama 3.1 要短
同时,中文能力相对上一代 mistral large 有大步幅提升:
引用链接
[1]
huggingface 上的权重: https://huggingface.co/mistralai/mathstral-7B-v0.1/blob/main/config.json[2]
mistral-7b-instruct-v0.2: https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2/blob/main/config.json[3]
Multi-Query Attention: https://arxiv.org/pdf/1911.02150.pdf[4]
为什么现在大家都在用 MQA 和 GQA?: https://zhuanlan.zhihu.com/p/647130255[5]
这个 redis: https://www.reddit.com/r/LocalLLaMA/comments/18jslmf/tokens_per_second_mistral_8x7b_performance/?rdt=57036[6]
TensorRT-LLM: https://developer.nvidia.com/blog/achieving-high-mixtral-8x7b-performance-with-nvidia-h100-tensor-core-gpus-and-tensorrt-llm/?ncid=so-twit-928467/[7]
这个博客: https://blogs.nvidia.com/blog/mistral-nvidia-ai-model/[8]
Megatron-LM: https://github.com/NVIDIA/Megatron-LM
包包算法笔记
包大人的大模型、深度学习、机器学习笔记。
152篇原创内容
公众号
相关文章