背景
使用 vllm 可以加速模型的推理过程;vllm 加速的代码也很少,只需要调用包即可实现,没有太大的学习成本,而且好处很多,可以极大提升模型的推理速度;
简介
使用 modelscope 的 chatglm3-6B,调用 vllm 加速推理;
我的显卡显存为 24G ;
chatglm3-6B,如果不用vllm,我的显存不够,必须使用half
才能放进显存 ;
使用 vllm 后,vllm 加载的大模型模型权重占用空间会小一点;不使用half
,恰好能放进我的显存空间;
实操
By default, vLLM downloads model from HuggingFace. If you would like to use models from ModelScope in the following examples, please set the environment variable:
export VLLM_USE_MODELSCOPE=True
from vllm import LLM, SamplingParams
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=2048)
SamplingParams?
Init signature:
SamplingParams(
n: int = 1,
best_of: Optional[int] = None,
presence_penalty: float = 0.0,
frequency_penalty: float = 0.0,
repetition_penalty: float = 1.0,
temperature: float = 1.0,
top_p: float = 1.0,
top_k: int = -1,
min_p: float = 0.0,
use_beam_search: bool = False,
length_penalty: float = 1.0,
early_stopping: Union[bool, str] = False,
stop: Union[List[str], str, NoneType] = None,
stop_token_ids: Optional[List[int]] = None,
include_stop_str_in_output: bool = False,
ignore_eos: bool = False,
max_tokens: Optional[int] = 16,
logprobs: Optional[int] = None,
prompt_logprobs: Optional[int] = None,
skip_special_tokens: bool = True,
spaces_between_special_tokens: bool = True,
logits_processors: Optional[List[Callable[[List[int], torch.Tensor], torch.Tensor]]] = None,
) -> None
Docstring:
Sampling parameters for text generation.
利用SamplingParams
的参数,控制大模型的输出结果;
使用大模型针对如下 prompt 做推理:
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"今天天气真好,咱们出去",
"明天就要开学了,我的作业还没写完,",
"请你介绍一下你自己。AI:"
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=2048)
model_path = 'ZhipuAI/chatglm3-6b'
llm = LLM(
model=model_path,
trust_remote_code=True,
tokenizer=model_path,
tokenizer_mode='slow',
tensor_parallel_size=1
)
outputs = llm.generate(prompts, sampling_params)
推理耗时7秒,速度快了很多;
查看推理结果
# Print the outputs.
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
代码公开
本文代码发布在 github:https://github.com/JieShenAI/csdn/blob/main/24/04/vllm/vllm_demo.ipynb