Core ideas and concepts
s1: Simple test-time scaling boils down to supervised fine-tuning Qwen2.5-32B-Instruct on the s1K dataset (the code simply uses the SFT trainer from the trl library) and equipping the model with "budget forcing"; the resulting s1-32B exceeds o1-preview by up to 27% on competition math problems (MATH and AIME24).
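For orientation, the training step really is vanilla SFT. Below is a minimal sketch using trl's `SFTTrainer`; the dataset field names (`question`, `attempt`) and all hyperparameters are illustrative placeholders, not the repo's actual training configuration:

```python
# Minimal SFT sketch with trl; field names and hyperparameters are
# illustrative placeholders, not the exact s1 training recipe.
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

ds = load_dataset("simplescaling/s1K", split="train")
# Flatten each example into a single "text" string of question plus
# reasoning trace / answer (hypothetical field names).
ds = ds.map(lambda ex: {"text": ex["question"] + "\n" + ex["attempt"]})

trainer = SFTTrainer(
    model="Qwen/Qwen2.5-32B-Instruct",
    train_dataset=ds,
    args=SFTConfig(output_dir="s1-sft", bf16=True, learning_rate=1e-5),
)
trainer.train()
```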
Budget forcing comes down to two moves (distilled in the sketch after this list):
- Capping thinking length: if the model generates more thinking tokens than the desired limit, the thinking process is forcibly terminated by appending the end-of-thinking token delimiter. Ending the thinking this way makes the model move on to generating its answer.
- Appending "Wait" after the output: if we want the model to spend more test-time compute on a question, generation of the end-of-thinking token delimiter is suppressed and "Wait" is appended to the model's current reasoning trace instead, encouraging further exploration.
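Stripped to its essentials, budget forcing is just string surgery around the decode call. A schematic sketch, where `generate` stands in for any decoder that stops on the end-of-thinking delimiter and reports whether it hit the token cap (all names here are illustrative helpers, not the s1 API):

```python
END_THINK = "<|im_end|>"  # s1's end-of-thinking delimiter

# Schematic only: `generate` is an assumed helper returning
# (generated_text, hit_token_cap); it is not part of the s1 codebase.
def budget_force(prompt, generate, think_budget, num_ignore):
    prompt += "<|im_start|>think"
    trace, hit_cap = generate(prompt, max_tokens=think_budget)
    if hit_cap:
        # Move 1: over budget -> force-close thinking so the model answers next.
        return prompt + trace + END_THINK
    for _ in range(num_ignore):
        # Move 2: under budget -> suppress END_THINK, append "Wait",
        # and nudge the model to keep exploring.
        prompt += trace + "Wait"
        trace, _ = generate(prompt, max_tokens=think_budget)
    return prompt + trace + END_THINK
```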
If the strategy still feels a bit opaque in prose, the code on GitHub makes it concrete; it is short and blunt. An excerpt:
```python
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer

# Decide on a token limit for thinking; as the model's max tokens is 32768,
# 32000 usually leaves enough space for the model to still answer
MAX_TOKENS_THINKING = 32000
# Decide how often to ignore the end-of-thinking token
NUM_IGNORE = 1

model = LLM(
    "simplescaling/s1-32B",
    tensor_parallel_size=2,
)
tok = AutoTokenizer.from_pretrained(
    "simplescaling/s1-32B"
)

stop_token_ids = tok("<|im_end|>")["input_ids"]
sampling_params = SamplingParams(
    max_tokens=32768,
    min_tokens=0,
    stop_token_ids=stop_token_ids,
    skip_special_tokens=False,
    temperature=0.0,
)

# For the exact raspberry sample in the paper, change
# model to `qfq/1k_qr_bt_dm_po_steps` (an earlier version of s1)
# & prompt to `How many r in raspberry?`
prompts = [
    "How many r in raspberry",
]

for i, p in enumerate(prompts):
    prompt = "<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n" + p + "<|im_end|>\n<|im_start|>assistant\n"
    # Phase 1: thinking, capped at MAX_TOKENS_THINKING tokens.
    stop_token_ids = tok("<|im_start|><|im_end|>")["input_ids"]
    sampling_params = SamplingParams(
        max_tokens=MAX_TOKENS_THINKING,
        min_tokens=0,
        stop_token_ids=stop_token_ids,
        skip_special_tokens=False,
        temperature=0.0,
    )
    prompt += "<|im_start|>think"
    o = model.generate(
        prompt,
        sampling_params=sampling_params,
    )
    # Phase 2: budget forcing. Suppress the end-of-thinking delimiter by
    # appending "Wait" and let the model keep reasoning within the budget.
    ignore_str = "Wait"
    max_tokens_thinking_tmp = MAX_TOKENS_THINKING
    for _ in range(NUM_IGNORE):  # num of times to skip the stop token
        max_tokens_thinking_tmp -= len(o[0].outputs[0].token_ids)
        if max_tokens_thinking_tmp <= 0:
            break  # thinking budget exhausted; go straight to the answer
        prompt += o[0].outputs[0].text + ignore_str
        sampling_params = SamplingParams(
            max_tokens=max_tokens_thinking_tmp,
            min_tokens=1,
            stop_token_ids=stop_token_ids,
            skip_special_tokens=False,
            temperature=0.0,
        )
        o = model.generate(
            prompt,
            sampling_params=sampling_params,
        )
    # Phase 3: final answer, stopping on <|im_end|> as usual.
    prompt += o[0].outputs[0].text
    stop_token_ids = tok("<|im_end|>")["input_ids"]
    sampling_params = SamplingParams(
        max_tokens=32768,
        min_tokens=0,
        stop_token_ids=stop_token_ids,
        skip_special_tokens=False,
        temperature=0.0,
    )
    o = model.generate(
        prompt,
        sampling_params=sampling_params,
    )
    print("With budget forcing:")
    print(prompt + o[0].outputs[0].text)
```
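Read top to bottom, the loop runs three decode phases: an initial thinking pass capped at MAX_TOKENS_THINKING, up to NUM_IGNORE rounds where the end-of-thinking stop is overridden by appending "Wait", and a final pass that emits the answer. Raising NUM_IGNORE (or MAX_TOKENS_THINKING) is how you buy the model more test-time compute on the same question.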
References
- Paper: https://arxiv.org/pdf/2501.19393
- Chinese translation (WeChat): https://mp.weixin.qq.com/s/L62T6rQZwVFhCLjeeN5tfg
- GitHub: https://github.com/simplescaling/s1