Required Python imports
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"  # select GPUs before vLLM is imported
import argparse
import json
from typing import AsyncGenerator
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse
import uvicorn
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.sampling_params import SamplingParams
from vllm.utils import random_uuid
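Note: the engine and sampling APIs used here (AsyncLLMEngine.from_engine_args, SamplingParams with use_beam_search) match older vLLM releases; newer versions have reorganized parts of this interface, so pin a vLLM version in which these imports resolve.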
Global variables
TIMEOUT_KEEP_ALIVE = 5 # seconds
TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
app = FastAPI()
The generate endpoint
@app.post("/generate")
async def generate(request: Request) -> Response:
"""
    Generate a completion for the request.

    The request should be a JSON object with the following fields:
    - data.context: the conversation turns used to build the prompt.
    - salt_uuid: a request identifier echoed back in the response.
    - stream: whether to stream the results or not.
"""
    salt_uuid = "null"  # defined up front so the except branch can reference it
    try:
        request_dict = await request.json()
        contexts = request_dict.get("data", {}).get("context")
        salt_uuid = request_dict.pop("salt_uuid", "null")
        prompt, message_doctor = process_context_qwen(contexts)
        stream = request_dict.pop("stream", False)
        # Beam-search decoding; tune temperature/top_p/top_k/max_tokens as needed,
        # or build SamplingParams(**request_dict) from the request body instead.
        sampling_params = SamplingParams(n=1, temperature=0, best_of=5, top_p=1.0,
                                         top_k=-1, use_beam_search=True, max_tokens=128)
        request_id = random_uuid()
        results_generator = engine.generate(prompt, sampling_params, request_id)
# Streaming case
async def stream_results() -> AsyncGenerator[bytes, None]:
async for request_output in results_generator:
prompt = request_output.prompt
text_outputs = [
prompt + output.text for output in request_output.outputs
]
ret = {"text": text_outputs}
yield (json.dumps(ret) + "\0").encode("utf-8")
async def abort_request() -> None:
await engine.abort(request_id)
if stream:
background_tasks = BackgroundTasks()
# Abort the request if the client disconnects.
background_tasks.add_task(abort_request)
return StreamingResponse(stream_results(), background=background_tasks)
# Non-streaming case
final_output = None
async for request_output in results_generator:
if await request.is_disconnected():
                # Abort the request if the client disconnects.
await engine.abort(request_id)
return Response(status_code=499)
final_output = request_output
assert final_output is not None
text_outputs = [output.text for output in final_output.outputs]
print(f"output:{final_output.outputs[0].text}")
ret = {"data": {"text": text_outputs}, "code": 5200, "message": "调试成功", "salt_uuid": salt_uuid}
except Exception as e:
ret = {"data": {"text": ""}, "code": 5201, "message": f"调用失败\n错误信息: {e}, ", "salt_uuid": salt_uuid}
return JSONResponse(ret)
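For reference, the request body this handler expects looks like the sketch below; the field names come straight from the code above, while the concrete values are illustrative.

# Illustrative request body for POST /generate (field names taken from the
# handler above; values are made up):
example_request = {
    "salt_uuid": "demo-123",   # echoed back in the response
    "stream": False,           # True selects the streaming path
    "data": {
        "context": [           # same structure as merged_json (see "Input data format")
            {"role": "用户", "message": "您好"},
        ],
    },
}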
Prompt/context builder for the Qwen model
def process_context_qwen(contexts):
    # Walk the context backwards and keep roughly the last 1024 characters.
    cur_index = 0
    char_count = 0
    for index, line_dict in enumerate(contexts[::-1]):
        char_count += len(line_dict["message"])
        if char_count >= 1024:
            cur_index = len(contexts) - index - 1
            break
    conversations_dataline = preprocessing(merged_json=contexts[cur_index:])[0]
    query = ''
    message_doctor = []
    for idx, datalines in enumerate(conversations_dataline["conversation"]):
        if idx != len(conversations_dataline["conversation"]) - 1:
if "human" in datalines:
human = datalines["human"]
query += f"<|im_start|>user\n{human}<|im_end|>\n"
if "assistant" in datalines:
assistant = datalines["assistant"]
message_doctor.append(assistant)
query += f"<|im_start|>assistant\n{assistant}<|im_end|>\n"
if "system" in datalines:
system = datalines["system"] + "。" if not datalines["system"].endswith("。") else datalines["system"]
query += f"<|im_start|>system\n{system}<|im_end|>\n"
        else:
            # Last turn: either continue an unfinished assistant message, or
            # open a fresh assistant turn for the model to complete.
if "assistant" in datalines:
assistant = datalines["assistant"]
message_doctor.append(assistant)
query += f"<|im_start|>assistant\n{assistant}\n"
else:
human = datalines["human"]
query += f"<|im_start|>user\n{human}<|im_end|>\n<|im_start|>assistant\n"
return query, "\n".join(message_doctor)
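As a sanity check, running this on the sample merged_json from the end of this document yields a ChatML-style prompt like the following (illustrative):

# prompt, doctor_msgs = process_context_qwen(merged_json) would give:
# prompt:
#   <|im_start|>system
#   时间:上午8点,性别:女。<|im_end|>
#   <|im_start|>user
#   您好<|im_end|>
#   <|im_start|>assistant
#   您需要什么吗
# doctor_msgs: "您需要什么吗"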
Prompt preprocessing function
def preprocessing(merged_json):
assistant_prefix = "助手"
patient_prefix = "用户"
system_prefix = "system"
conversations_datalines = []
conversations_id = 1
conversations = []
content = ''
prev_role = None
for idx, sentence in enumerate(merged_json):
cur_role = None
if sentence["role"] == assistant_prefix:
cur_role = "assistant"
elif sentence["role"] == system_prefix:
cur_role = "system"
elif sentence["role"] == patient_prefix:
cur_role = "human"
        # Flush the buffer whenever the speaker changes.
        if cur_role is not None and prev_role is not None and cur_role != prev_role:
            conversations.append({prev_role: content})
            content = ""
        # Accumulate consecutive same-speaker messages into one turn.
        message = sentence["message"].strip()
        content = f"{content}\n{message}" if content else message
        if idx == len(merged_json) - 1:
            # Flush the trailing content: merge it into the previous turn when the
            # role matches (or is unrecognized), otherwise start a new turn.
            if conversations and ((cur_role is not None and list(conversations[-1].keys())[0] == cur_role) or cur_role is None):
                conversations[-1][list(conversations[-1].keys())[0]] += '\n' + content
            else:
                conversations.append({cur_role if cur_role is not None else prev_role: content})
if cur_role is not None:
prev_role = cur_role
if conversations:
conversations_datalines.append({
"conversations_id": conversations_id,
"category": "qwen",
"conversation": conversations,
"dataset": "yyds"
})
conversations_id += 1
return conversations_datalines
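For the sample input at the end of this document, consecutive same-role messages are merged and the function returns a structure like this (illustrative):

# preprocessing(merged_json) would return:
# [{
#     "conversations_id": 1,
#     "category": "qwen",
#     "conversation": [
#         {"system": "时间:上午8点,性别:女"},
#         {"human": "您好"},
#         {"assistant": "您需要什么吗"},
#     ],
#     "dataset": "yyds",
# }]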
Main entry point
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=12356)
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()
engine_args = AsyncEngineArgs.from_cli_args(args)
engine = AsyncLLMEngine.from_engine_args(engine_args)
uvicorn.run(app,
host=args.host,
port=args.port,
log_level="info",
timeout_keep_alive=TIMEOUT_KEEP_ALIVE)
Terminal command
Combine the code blocks above, in order, into a single .py file, then launch the server:
python xxx.py --model your_model_path --tensor-parallel-size 4 --gpu-memory-utilization 0.95 --trust-remote-code --dtype half
Input data format
# The "role" values must match the prefixes checked in preprocessing():
# "system", "用户" (user), and "助手" (assistant).
merged_json = [
    {
        "role": "system",
        "message": "时间:上午8点,性别:女"  # "Time: 8 a.m.; gender: female"
    },
    {
        "role": "用户",
        "message": "您好"  # "Hello"
    },
    {
        "role": "助手",
        "message": "您需要什么吗"  # "Is there anything you need?"
    }
]
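With the server running, the sample context can be posted to /generate like this (a minimal client sketch; the host and port match the argparse defaults above, and the requests library is assumed to be installed):

import requests

payload = {
    "salt_uuid": "demo-123",
    "data": {"context": merged_json},  # merged_json as defined above
}
resp = requests.post("http://127.0.0.1:12356/generate", json=payload, timeout=60)
print(resp.json())
# e.g. {"data": {"text": ["..."]}, "code": 5200, "message": "success", "salt_uuid": "demo-123"}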