大语言模型的后处理

后处理的输入

常规意义上的大模型处理流程

import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

# 加载模型和tokenizer
model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf")
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")

# 输入prompt
prompt = "Hello, I'm Claude. How can I assist you today?"
input_ids = tokenizer.encode(prompt, return_tensors="pt")

# 前向传播获取logits
output = model(input_ids=input_ids)
logits = output.logits

# logits形状: (batch_size, sequence_length, vocab_size)
print(logits.shape)

后处理的输入是logits,其实准确说是hidden states,经过embedding table 映射后得到了最终的logits。

# 采样超参数
temperature = 0.7
top_k = 50
top_p = 0.95
repetition_penalty = 1.2

# 对logits进行处理
logits = logits[:, -1, :] / temperature  # 应用温度
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
filtered_logits = enforce_repetition_penalty(filtered_logits, input_ids, repetition_penalty)

# 从处理后的logits中采样token
probabilities = torch.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, num_samples=1)

# 将新token添加到输入中,继续生成
input_ids = torch.cat((input_ids, next_token), dim=-1)
  • 首先定义了一些采样超参数,如温度(temperature)、top-k、top-p 和重复惩罚系数(repetition_penalty)。
  • 接下来, 对 logits 进行处理:
    • 应用温度缩放,控制输出的随机性。
    • 使用 top-k 和 top-p 过滤,保留概率最高的 k 个 token 或累积概率达到 p 的 token。
    • 应用重复惩罚,降低已生成 token 的概率,避免重复。

VLLM Sampler的处理

我们默认跑vllm benchmark test 的时候,sampling 参数配置:

    sampling_params = SamplingParams(
        n=args.n,
        # temperature=0.0 if args.use_beam_search else 1.0,
        temperature=0.0,
        top_p=1.0,
        use_beam_search=args.use_beam_search,
        ignore_eos=ignore_eos,
        max_tokens=max_tokens,
        repetition_penalty=args.repetition_penalty
    )

除了这些参数以外,SamplingParams(vllm/sampling_params.py)的默认配置我们主要关注:
在这里插入图片描述
其中由于temperature 设置为0,默认使用greedy sampling 方式进行logits 采样。
进入到Sampler 后处理(vllm/model_executor/layers/sampler.py,vllm/model_executor/sampling_metadata.py),do_top_p_top_kdo_min_p 采样bypass,最后softmax的输入shape 没有经过topk/p 的采样,输入shape为[bs, input_size, vocabulary_size]
在这里插入图片描述

因此,vocabulary size 如果太大,对softmax 性能的影响是一个很大的挑战。
从性能优化的角度考虑,可以先做一次logit 采样,通过设定合适的p/k 值保证模型输出精度。

  • 4
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值