ming_api_server参照openAI的RESTful API和fastchat,实现了提供springboot后端访问的对话生成接口。其通过与model_worker通信来获取和解析模型生成的文本内容,并能够接收和返回json格式化信息
1、文本生成
仿照OpenAI的RESTful API风格,访问地址/v1/chat/completions时可以获得模型聊天的输出
该函数解析来自外部的json信息，对其进行错误诊断，并根据得到的model参数请求controller以获得目标模型的model_worker所在的地址，随后将其余参数传递给get_gen_params函数，将json中的参数转换成python对象形式。其中参数n的作用是控制最终生成文本的条数。
利用定义好的对象对返回的聊天信息进行包装,以便返回json格式化文本
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint.

    Validates the request, resolves the worker address for the requested
    model, builds generation parameters, fans out ``request.n`` concurrent
    generation calls, and wraps the results in an OpenAI-style response.
    """
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret
    worker_addr = await get_worker_address(request.model)
    # Translate the JSON payload into the worker's generation parameters.
    gen_params = await get_gen_params(
        request.model,
        request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        top_k=request.top_k,
        presence_penalty=request.presence_penalty,
        frequency_penalty=request.frequency_penalty,
        max_tokens=request.max_tokens,
        echo=False,
        stop=request.stop,
    )
    choices = []
    chat_completions = []
    # Launch `n` generations concurrently; all are awaited together below.
    for i in range(request.n):
        content = asyncio.create_task(generate_completion(gen_params, worker_addr))
        chat_completions.append(content)
    try:
        all_tasks = await asyncio.gather(*chat_completions)
    except Exception as e:
        return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
    # Return an (empty) UsageInfo rather than None, consistent with the
    # complete implementation later in this file.
    usage = UsageInfo()
    for i, content in enumerate(all_tasks):
        if isinstance(content, str):
            content = json.loads(content)
        if content["error_code"] != 0:
            return create_error_response(content["error_code"], content["text"])
        choices.append(
            ChatCompletionResponseChoice(
                index=i,
                message=ChatMessage(role="assistant", content=content["text"]),
                finish_reason=content.get("finish_reason", "stop"),
            )
        )
    return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
generate_completion是实际与model_worker通信的函数，其以流的方式向model_worker发出生成文本的POST请求，并对model_worker返回的字节流进行解码
async def generate_completion(payload: Dict[str, Any], worker_addr: str):
    """POST the generation payload to the worker's streaming endpoint and
    return the fully-read response text with NUL padding stripped."""
    url = worker_addr + "/worker_generate_stream"
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST", url, headers=headers, json=payload, timeout=WORKER_API_TIMEOUT
        ) as response:
            raw = await response.aread()
            return raw.replace(b'\0', b'').decode()
2、多轮对话
多轮对话的核心实现逻辑非常简单,就是把模型过往的对话记录缓存起来,生成的时候将这些聊天记录全部放在请求中。对话记录遵从{role:"xxx",content:"xxx"}的形式,role为user时说明这条信息是用户提问,role为assistant说明是系统的回答
在py端只有参数解析的部分，其将用户和大模型的对话记录抽取出来以做进一步生成
聊天记录缓存的功能我们选择在springboot后端实现。
完整代码
import asyncio
import argparse
import json
import os
from typing import Generator, Optional, Union, Dict, List, Any
import aiohttp
import fastapi
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
import httpx
from ming.conversations import conv_templates, get_default_conv_template, SeparatorStyle
from pydantic_settings import BaseSettings
import shortuuid
import tiktoken
import uvicorn
from fastchat.constants import (
WORKER_API_TIMEOUT,
ErrorCode,
)
from fastchat.protocol.openai_api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatMessage,
ChatCompletionResponseChoice,
ErrorResponse,
UsageInfo
)
# Cache for conversation templates; NOTE(review): not referenced elsewhere
# in this file beyond initialization.
conv_template_map = {}
# Generous 3-hour total timeout to accommodate long streaming generations.
fetch_timeout = aiohttp.ClientTimeout(total=3 * 3600)
async def fetch_remote(url, pload=None, name=None):
    """POST ``pload`` as JSON to ``url`` and collect the streamed reply.

    Args:
        url: Target endpoint.
        pload: JSON-serializable request body.
        name: When not None, the response body is parsed as JSON; a
            non-empty name additionally selects that field from the
            parsed object.

    Returns:
        The raw response bytes (NUL separators stripped), or the parsed
        JSON (sub-)object when ``name`` is given, or a JSON-encoded error
        string on a non-200 HTTP status.
    """
    async with aiohttp.ClientSession(timeout=fetch_timeout) as session:
        async with session.post(url, json=pload) as response:
            if response.status != 200:
                ret = {
                    "text": f"{response.reason}",
                    "error_code": ErrorCode.INTERNAL_ERROR,
                }
                return json.dumps(ret)
            chunks = []
            async for chunk, _ in response.content.iter_chunks():
                chunks.append(chunk)
    # Workers use NUL bytes as chunk separators; strip them before parsing.
    output = b"".join(chunks).replace(b'\x00', b'')
    if name is not None:
        res = json.loads(output)
        if name != "":
            res = res[name]
        return res
    # Removed leftover debug print(output) that polluted stdout on every call.
    return output
class AppSettings(BaseSettings):
    # Address of the controller used to look up model workers; being a
    # pydantic BaseSettings field, it can be overridden via environment.
    controller_address: str = "http://localhost:21001"
# Module-level singletons: settings, the FastAPI application, and the
# default headers sent with every request to a model worker.
app_settings = AppSettings()
app = fastapi.FastAPI()
headers = {"User-Agent": "MING API Server"}
def create_error_response(code: int, message: str) -> JSONResponse:
    """Wrap an error code and message into an OpenAI-style JSON error
    body, returned with HTTP status 400."""
    payload = ErrorResponse(message=message, code=code)
    return JSONResponse(payload.model_dump(), status_code=400)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    # Convert FastAPI/pydantic request-validation failures into the same
    # JSON error shape used by the rest of this API.
    return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(exc))
def check_requests(request) -> Optional[JSONResponse]:
    """Validate the sampling parameters of an incoming request.

    Returns a JSON error response describing the first out-of-range
    parameter, or None when every parameter is acceptable.
    """
    stop_ok = request.stop is None or isinstance(request.stop, (str, list))
    # (condition-that-fails, error message) pairs, checked in order.
    violations = (
        (request.max_tokens is not None and request.max_tokens <= 0,
         f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'"),
        (request.n is not None and request.n <= 0,
         f"{request.n} is less than the minimum of 1 - 'n'"),
        (request.temperature is not None and request.temperature < 0,
         f"{request.temperature} is less than the minimum of 0 - 'temperature'"),
        (request.temperature is not None and request.temperature > 2,
         f"{request.temperature} is greater than the maximum of 2 - 'temperature'"),
        (request.top_p is not None and request.top_p < 0,
         f"{request.top_p} is less than the minimum of 0 - 'top_p'"),
        (request.top_p is not None and request.top_p > 1,
         f"{request.top_p} is greater than the maximum of 1 - 'top_p'"),
        (request.top_k is not None and -1 < request.top_k < 1,
         f"{request.top_k} is out of Range. Either set top_k to -1 or >=1."),
        (not stop_ok,
         f"{request.stop} is not valid under any of the given schemas - 'stop'"),
    )
    for failed, message in violations:
        if failed:
            return create_error_response(ErrorCode.PARAM_OUT_OF_RANGE, message)
    return None
def _add_to_set(s, new_stop):
if not s:
return
if isinstance(s, str):
new_stop.add(s)
else:
new_stop.update(s)
async def get_gen_params(
    model_name: str,
    messages: Union[str, List[Dict[str, str]]],
    *,
    temperature: float,
    top_p: float,
    top_k: Optional[int],
    presence_penalty: Optional[float],
    frequency_penalty: Optional[float],
    max_tokens: Optional[int],
    echo: Optional[bool],
    stop: Optional[Union[str, List[str]]],
) -> Dict[str, Any]:
    """Build the generation-parameter dict sent to a model worker.

    ``messages`` is either a raw prompt string or an OpenAI-style message
    list; in the latter case the turns are rendered into a single prompt
    via the "qwen" conversation template.

    Raises:
        ValueError: if a message carries an unknown role.
    """
    conv = conv_templates["qwen"].copy()
    if isinstance(messages, str):
        prompt = messages
    else:
        for message in messages:
            msg_role = message["role"]
            if msg_role == "system":
                conv.system = message['content']
            elif msg_role == "user":
                # Multimodal-style content list: keep only text fragments.
                # (was `type(...) == list` — isinstance is the correct check)
                if isinstance(message["content"], list):
                    text_list = [
                        item["text"]
                        for item in message["content"]
                        if item["type"] == "text"
                    ]
                    text = "\n".join(text_list)
                    conv.append_message(conv.roles[0], text)
                else:
                    conv.append_message(conv.roles[0], message["content"])
            elif msg_role == "assistant":
                conv.append_message(conv.roles[1], message["content"])
            else:
                raise ValueError(f"Unknown role: {msg_role}")
        # Trailing empty assistant turn tells the template to generate.
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
        "max_new_tokens": max_tokens,
        "echo": echo,
        # NOTE(review): this forwards the raw stop strings; workers usually
        # expect token *ids* under this key — confirm against the worker API.
        "stop_token_ids": stop
    }
    new_stop = set()
    _add_to_set(stop, new_stop)
    gen_params["stop"] = list(new_stop)
    return gen_params
async def get_worker_address(model_name: str) -> str:
    """Resolve, via the controller, the worker address serving ``model_name``.

    Raises:
        ValueError: if the controller reports no available worker.
    """
    endpoint = app_settings.controller_address + "/get_worker_address"
    addr = await fetch_remote(endpoint, {"model": model_name}, "address")
    if addr == "":
        raise ValueError(f"No available worker for {model_name}")
    return addr
@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion endpoint.

    Validates the request, looks up the worker serving the requested
    model, issues ``request.n`` concurrent generation calls, and packs
    the results into an OpenAI-style response object.
    """
    invalid = check_requests(request)
    if invalid is not None:
        return invalid
    worker_addr = await get_worker_address(request.model)
    # Translate the JSON request into worker generation parameters.
    gen_params = await get_gen_params(
        request.model,
        request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        top_k=request.top_k,
        presence_penalty=request.presence_penalty,
        frequency_penalty=request.frequency_penalty,
        max_tokens=request.max_tokens,
        echo=False,
        stop=request.stop,
    )
    # Fan out n generations concurrently and await them together.
    tasks = [
        asyncio.create_task(generate_completion(gen_params, worker_addr))
        for _ in range(request.n)
    ]
    try:
        results = await asyncio.gather(*tasks)
    except Exception as e:
        return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))
    usage = UsageInfo()
    choices = []
    for idx, result in enumerate(results):
        if isinstance(result, str):
            result = json.loads(result)
        if result["error_code"] != 0:
            return create_error_response(result["error_code"], result["text"])
        choices.append(
            ChatCompletionResponseChoice(
                index=idx,
                message=ChatMessage(role="assistant", content=result["text"]),
                finish_reason=result.get("finish_reason", "stop"),
            )
        )
    return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
async def generate_completion(payload: Dict[str, Any], worker_addr: str):
    """POST the generation payload to the worker's streaming endpoint and
    return the fully-read response text with NUL padding stripped."""
    url = worker_addr + "/worker_generate_stream"
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST", url, headers=headers, json=payload, timeout=WORKER_API_TIMEOUT
        ) as response:
            raw = await response.aread()
            return raw.replace(b'\0', b'').decode()
def create_api_server():
    """Parse CLI arguments, install the CORS middleware, point the app at
    the chosen controller, and return the parsed args for the caller."""
    parser = argparse.ArgumentParser(description="Simple RESTful API server.")
    parser.add_argument("--host", type=str, default="localhost", help="host name")
    parser.add_argument("--port", type=int, default=6006, help="port number")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument(
        "--allow-credentials", action="store_true", help="allow credentials"
    )
    # The three CORS list options accept JSON, e.g. '["*"]'.
    for flag, help_text in (
        ("--allowed-origins", "allowed origins"),
        ("--allowed-methods", "allowed methods"),
        ("--allowed-headers", "allowed headers"),
    ):
        parser.add_argument(flag, type=json.loads, default=["*"], help=help_text)
    args = parser.parse_args()
    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )
    app_settings.controller_address = args.controller_address
    return args
if __name__ == "__main__":
    # Script entry point: parse CLI args/configure the app, then serve it.
    args = create_api_server()
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")