项目实训——大模型篇2:实现RESTful API访问

ming_api_server 参照 OpenAI 的 RESTful API 和 FastChat,实现了供 Spring Boot 后端调用的对话生成接口。它通过与 model_worker 通信来获取并解析模型生成的文本,接收和返回的都是 JSON 格式的数据。

1、文本生成

仿照 OpenAI 的 RESTful API 风格,向 /v1/chat/completions 发送 POST 请求即可获得模型的对话输出。

该函数解析外部传入的 JSON 请求并进行参数校验,然后根据其中的 model 参数向 controller 查询目标模型所在的 model_worker 地址,再把其余参数交给 get_gen_params 函数,将 JSON 中的对话消息和采样参数转换成发往 model_worker 的 Python 字典形式的生成参数。其中参数 n 控制最终生成文本的条数:函数会并发地向 model_worker 发起 n 次生成请求。

最后用预先定义好的响应对象(ChatCompletionResponse、ChatCompletionResponseChoice、ChatMessage)包装模型返回的聊天内容,从而以 JSON 格式返回给调用方。

@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible, non-streaming chat completion endpoint.

    Validates the request and asks the controller for the worker serving
    ``request.model``. It then turns the chat history into generation
    parameters and requests ``n`` completions from that worker concurrently.

    Returns:
        A ``ChatCompletionResponse`` with one choice per completion, or an error
        ``JSONResponse`` built by ``create_error_response`` when validation,
        worker lookup, prompt construction or generation fails.
    """
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret

    try:
        worker_addr = await get_worker_address(request.model)
    except ValueError as e:
        # get_worker_address raises ValueError when the controller reports no
        # worker for this model; answer with an API error instead of a bare 500.
        return create_error_response(ErrorCode.INVALID_MODEL, str(e))

    # Parse the request into worker generation parameters.
    try:
        gen_params = await get_gen_params(
            request.model,
            request.messages,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            presence_penalty=request.presence_penalty,
            frequency_penalty=request.frequency_penalty,
            max_tokens=request.max_tokens,
            echo=False,
            stop=request.stop,
        )
    except ValueError as e:
        # get_gen_params raises ValueError for a message role it does not know.
        return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(e))

    # `n` may arrive as an explicit null; fall back to a single completion.
    n = request.n if request.n is not None else 1
    tasks = [
        asyncio.create_task(generate_completion(gen_params, worker_addr))
        for _ in range(n)
    ]
    try:
        outputs = await asyncio.gather(*tasks)
    except Exception as e:
        # gather() does not cancel the remaining tasks when one of them fails;
        # stop them so they do not keep generating for a discarded response.
        for task in tasks:
            task.cancel()
        return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))

    # ChatCompletionResponse requires a UsageInfo object (None fails validation).
    usage = UsageInfo()
    choices = []
    for i, content in enumerate(outputs):
        if isinstance(content, str):
            try:
                content = json.loads(content)
            except json.JSONDecodeError as e:
                return create_error_response(
                    ErrorCode.INTERNAL_ERROR, f"Invalid worker response: {e}"
                )
        if content["error_code"] != 0:
            return create_error_response(content["error_code"], content["text"])
        choices.append(
            ChatCompletionResponseChoice(
                index=i,
                message=ChatMessage(role="assistant", content=content["text"]),
                finish_reason=content.get("finish_reason", "stop"),
            )
        )
    return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)

generate_completion 是实际与 model_worker 通信的函数:它以流式方式向 model_worker 的 /worker_generate_stream 接口发送生成文本的 POST 请求,并把 model_worker 返回的字节数据解码成字符串。

async def generate_completion(payload: Dict[str, Any], worker_addr: str):
    """POST ``payload`` to the worker's streaming endpoint and return its final message.

    The whole response body is read and split on NUL (``b"\\0"``) separators.
    The last non-empty message is decoded and returned as a JSON string. When
    the worker sends a single message, this equals stripping the NUL bytes.
    When it sends several, the messages are not concatenated into unparseable
    JSON.

    NOTE(review): keeping only the last message assumes each streamed message
    carries the cumulative text generated so far — confirm for the worker in use.

    Returns:
        A JSON string. On a non-200 worker status it is
        ``{"text": ..., "error_code": ErrorCode.INTERNAL_ERROR}``, mirroring
        ``fetch_remote``. An empty body yields ``""``.
    """
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST",
            worker_addr + "/worker_generate_stream",
            headers=headers,
            json=payload,
            timeout=WORKER_API_TIMEOUT,
        ) as response:
            if response.status_code != 200:
                return json.dumps(
                    {
                        "text": f"worker returned HTTP {response.status_code} "
                        f"{response.reason_phrase}",
                        "error_code": ErrorCode.INTERNAL_ERROR,
                    }
                )
            raw = await response.aread()
    messages = [chunk for chunk in raw.split(b"\0") if chunk.strip()]
    return messages[-1].decode() if messages else ""

2、多轮对话

多轮对话的核心实现逻辑非常简单:把模型过往的对话记录缓存起来,生成时将这些聊天记录全部放进请求。对话记录遵循 {role:"xxx", content:"xxx"} 的形式:role 为 user 时表示这条信息是用户提问,role 为 assistant 时表示这是模型的回答,role 为 system 时则作为系统提示词。

Python 端只负责参数解析:get_gen_params 会把用户与大模型的历史对话按角色依次填入对话模板,拼接成完整的 prompt,交给模型继续生成。

聊天记录缓存的功能我们选择在springboot后端实现。

完整代码

import asyncio
import argparse
import json
import os
from typing import Generator, Optional, Union, Dict, List, Any

import aiohttp
import fastapi
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
import httpx

from ming.conversations import conv_templates, get_default_conv_template, SeparatorStyle
from pydantic_settings import BaseSettings
import shortuuid
import tiktoken
import uvicorn

from fastchat.constants import (
    WORKER_API_TIMEOUT,
    ErrorCode,
)

from fastchat.protocol.openai_api_protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatMessage,
    ChatCompletionResponseChoice,
    ErrorResponse,
    UsageInfo
)


# Module-level template cache; not read by any function shown in this file.
conv_template_map = dict()

# Client timeout for fetch_remote's aiohttp sessions: three hours in total.
fetch_timeout = aiohttp.ClientTimeout(total=3 * 60 * 60)


async def fetch_remote(url, pload=None, name=None):
    """POST an optional JSON body to ``url`` and return the response.

    Args:
        url: endpoint to call.
        pload: optional JSON-serialisable request body.
        name: ``None`` returns the raw body (bytes); ``""`` returns the decoded
            JSON document; any other string returns that key of the decoded
            JSON object.

    Returns:
        For a 200 response, NUL bytes are stripped and the result is returned
        as described for ``name``. For any other status, a JSON *string*
        ``{"text": <reason>, "error_code": ErrorCode.INTERNAL_ERROR}`` is
        returned regardless of ``name``.
        NOTE(review): callers asking for a key cannot tell that string apart
        from a real value.
    """
    async with aiohttp.ClientSession(timeout=fetch_timeout) as session:
        async with session.post(url, json=pload) as response:
            if response.status != 200:
                ret = {
                    "text": f"{response.reason}",
                    "error_code": ErrorCode.INTERNAL_ERROR,
                }
                return json.dumps(ret)
            body = await response.read()

    output = body.replace(b"\x00", b"")
    if name is None:
        return output
    res = json.loads(output)
    return res if name == "" else res[name]


class AppSettings(BaseSettings):
    """Runtime settings of the API server.

    ``controller_address`` is the base URL of the controller queried for worker
    addresses. ``create_api_server`` overwrites it from ``--controller-address``.
    As a pydantic ``BaseSettings`` field it can presumably also come from the
    environment (``CONTROLLER_ADDRESS``); confirm against the installed
    pydantic-settings version.
    """

    controller_address: str = "http://localhost:21001"


# User-Agent sent with every request that generate_completion makes to a worker.
headers = {"User-Agent": "MING API Server"}
app_settings = AppSettings()
app = fastapi.FastAPI()


def create_error_response(code: int, message: str) -> JSONResponse:
    """Build an OpenAI-style error body and return it with HTTP status 400.

    Args:
        code: error code placed in the body (typically an ``ErrorCode`` member).
        message: human-readable description of the error.
    """
    payload = ErrorResponse(message=message, code=code).model_dump()
    return JSONResponse(payload, status_code=400)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Report FastAPI request-validation failures as a VALIDATION_TYPE_ERROR body."""
    detail = str(exc)
    return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, detail)



def check_requests(request) -> Optional[JSONResponse]:
    """Validate the sampling parameters of a chat-completion request.

    Only parameters that are not ``None`` are checked, so omitted parameters
    always pass.

    Args:
        request: the parsed request; ``max_tokens``, ``n``, ``temperature``,
            ``top_p``, ``top_k`` and ``stop`` are read.

    Returns:
        ``None`` when every checked parameter is in range. Otherwise, an error
        response (``ErrorCode.PARAM_OUT_OF_RANGE``) describing the first
        offending parameter.
    """
    if request.max_tokens is not None and request.max_tokens <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
        )
    if request.n is not None and request.n <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.n} is less than the minimum of 1 - 'n'",
        )
    if request.temperature is not None and request.temperature < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is less than the minimum of 0 - 'temperature'",
        )
    if request.temperature is not None and request.temperature > 2:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
        )
    if request.top_p is not None and request.top_p < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is less than the minimum of 0 - 'top_p'",
        )
    if request.top_p is not None and request.top_p > 1:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
        )
    # top_k must be exactly -1 or at least 1, as the message says; values below
    # -1 (e.g. -5) are rejected as well as 0.
    if request.top_k is not None and request.top_k != -1 and request.top_k < 1:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.",
        )
    if request.stop is not None and (
            not isinstance(request.stop, str) and not isinstance(request.stop, list)
    ):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.stop} is not valid under any of the given schemas - 'stop'",
        )
    return None


def _add_to_set(s, new_stop):
    if not s:
        return
    if isinstance(s, str):
        new_stop.add(s)
    else:
        new_stop.update(s)


async def get_gen_params(
        model_name: str,
        messages: Union[str, List[Dict[str, str]]],
        *,
        temperature: float,
        top_p: float,
        top_k: Optional[int],
        presence_penalty: Optional[float],
        frequency_penalty: Optional[float],
        max_tokens: Optional[int],
        echo: Optional[bool],
        stop: Optional[Union[str, List[str]]],
) -> Dict[str, Any]:
    """Build the generation payload that is POSTed to a model_worker.

    A plain-string ``messages`` is used verbatim as the prompt. A list of
    ``{"role": ..., "content": ...}`` messages is replayed into a copy of the
    ``"qwen"`` conversation template, whatever ``model_name`` is:

    - ``system`` replaces the template's system prompt.
    - ``user`` and ``assistant`` turns are appended in order. For a ``user``
      message whose content is a list of parts, only the ``"text"`` parts are
      kept, joined by newlines.
    - An assistant turn with ``None`` content is appended before rendering the
      prompt.

    Returns:
        A dict with ``model``, ``prompt``, ``temperature``, ``top_p``, ``top_k``,
        ``presence_penalty``, ``frequency_penalty``, ``max_new_tokens`` (taken
        from ``max_tokens``), ``echo``, ``stop_token_ids`` (``None``) and
        ``stop`` (the de-duplicated stop strings as a list).

    Raises:
        ValueError: if a message role is not system, user or assistant.
    """
    conv = conv_templates["qwen"].copy()

    if isinstance(messages, str):
        prompt = messages
    else:
        for message in messages:
            msg_role = message["role"]
            if msg_role == "system":
                conv.system = message['content']
            elif msg_role == "user":
                content = message["content"]
                if isinstance(content, list):
                    # List of content parts: keep only the text parts.
                    content = "\n".join(
                        item["text"] for item in content if item["type"] == "text"
                    )
                conv.append_message(conv.roles[0], content)
            elif msg_role == "assistant":
                conv.append_message(conv.roles[1], message["content"])
            else:
                raise ValueError(f"Unknown role: {msg_role}")
        # Open an empty assistant turn; presumably this makes the rendered
        # prompt end where the model's reply should start.
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
        "max_new_tokens": max_tokens,
        "echo": echo,
        # Stop *strings* are sent under "stop" below. A token-id field must not
        # receive them: strings never equal integer ids, and a bare str breaks
        # list-style handling.
        # NOTE(review): assumes the worker treats None as "no extra stop ids".
        "stop_token_ids": None,
    }

    new_stop = set()
    _add_to_set(stop, new_stop)
    gen_params["stop"] = list(new_stop)

    return gen_params

async def get_worker_address(model_name: str) -> str:
    """Ask the controller for the address of the worker serving ``model_name``.

    Raises:
        ValueError: when the controller answers with an empty address.
    """
    url = f"{app_settings.controller_address}/get_worker_address"
    address = await fetch_remote(url, {"model": model_name}, "address")
    if address == "":
        raise ValueError(f"No available worker for {model_name}")
    return address

@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible, non-streaming chat completion endpoint.

    Validates the request and asks the controller for the worker serving
    ``request.model``. It then turns the chat history into generation
    parameters and requests ``n`` completions from that worker concurrently.

    Returns:
        A ``ChatCompletionResponse`` with one choice per completion, or an error
        ``JSONResponse`` built by ``create_error_response`` when validation,
        worker lookup, prompt construction or generation fails.
    """
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret

    try:
        worker_addr = await get_worker_address(request.model)
    except ValueError as e:
        # get_worker_address raises ValueError when the controller reports no
        # worker for this model; answer with an API error instead of a bare 500.
        return create_error_response(ErrorCode.INVALID_MODEL, str(e))

    # Parse the request into worker generation parameters.
    try:
        gen_params = await get_gen_params(
            request.model,
            request.messages,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            presence_penalty=request.presence_penalty,
            frequency_penalty=request.frequency_penalty,
            max_tokens=request.max_tokens,
            echo=False,
            stop=request.stop,
        )
    except ValueError as e:
        # get_gen_params raises ValueError for a message role it does not know.
        return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, str(e))

    # `n` may arrive as an explicit null; fall back to a single completion.
    n = request.n if request.n is not None else 1
    tasks = [
        asyncio.create_task(generate_completion(gen_params, worker_addr))
        for _ in range(n)
    ]
    try:
        outputs = await asyncio.gather(*tasks)
    except Exception as e:
        # gather() does not cancel the remaining tasks when one of them fails;
        # stop them so they do not keep generating for a discarded response.
        for task in tasks:
            task.cancel()
        return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))

    usage = UsageInfo()
    choices = []
    for i, content in enumerate(outputs):
        if isinstance(content, str):
            try:
                content = json.loads(content)
            except json.JSONDecodeError as e:
                return create_error_response(
                    ErrorCode.INTERNAL_ERROR, f"Invalid worker response: {e}"
                )
        if content["error_code"] != 0:
            return create_error_response(content["error_code"], content["text"])
        choices.append(
            ChatCompletionResponseChoice(
                index=i,
                message=ChatMessage(role="assistant", content=content["text"]),
                finish_reason=content.get("finish_reason", "stop"),
            )
        )
    return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)

async def generate_completion(payload: Dict[str, Any], worker_addr: str):
    """POST ``payload`` to the worker's streaming endpoint and return its final message.

    The whole response body is read and split on NUL (``b"\\0"``) separators.
    The last non-empty message is decoded and returned as a JSON string. When
    the worker sends a single message, this equals stripping the NUL bytes.
    When it sends several, the messages are not concatenated into unparseable
    JSON.

    NOTE(review): keeping only the last message assumes each streamed message
    carries the cumulative text generated so far — confirm for the worker in use.

    Returns:
        A JSON string. On a non-200 worker status it is
        ``{"text": ..., "error_code": ErrorCode.INTERNAL_ERROR}``, mirroring
        ``fetch_remote``. An empty body yields ``""``.
    """
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST",
            worker_addr + "/worker_generate_stream",
            headers=headers,
            json=payload,
            timeout=WORKER_API_TIMEOUT,
        ) as response:
            if response.status_code != 200:
                return json.dumps(
                    {
                        "text": f"worker returned HTTP {response.status_code} "
                        f"{response.reason_phrase}",
                        "error_code": ErrorCode.INTERNAL_ERROR,
                    }
                )
            raw = await response.aread()
    messages = [chunk for chunk in raw.split(b"\0") if chunk.strip()]
    return messages[-1].decode() if messages else ""


def create_api_server():
    """Parse command-line options and configure the module-level ``app``.

    Installs ``CORSMiddleware`` with the requested origins, methods and headers.
    It also points ``app_settings.controller_address`` at
    ``--controller-address``. The three ``--allowed-*`` options are parsed with
    ``json.loads``, e.g. ``'["http://localhost:8080"]'``.

    Returns:
        The parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Simple RESTful API server.")
    parser.add_argument("--host", type=str, default="localhost", help="host name")
    parser.add_argument("--port", type=int, default=6006, help="port number")
    parser.add_argument("--controller-address", type=str, default="http://localhost:21001")
    parser.add_argument("--allow-credentials", action="store_true", help="allow credentials")
    for option, help_text in (
        ("--allowed-origins", "allowed origins"),
        ("--allowed-methods", "allowed methods"),
        ("--allowed-headers", "allowed headers"),
    ):
        parser.add_argument(option, type=json.loads, default=["*"], help=help_text)

    cli = parser.parse_args()
    app.add_middleware(
        CORSMiddleware,
        allow_origins=cli.allowed_origins,
        allow_credentials=cli.allow_credentials,
        allow_methods=cli.allowed_methods,
        allow_headers=cli.allowed_headers,
    )
    app_settings.controller_address = cli.controller_address
    return cli


if __name__ == "__main__":
    # Parse the CLI (which also installs CORS and sets the controller address),
    # then serve the FastAPI app with uvicorn.
    cli_args = create_api_server()
    uvicorn.run(app, host=cli_args.host, port=cli_args.port, log_level="info")

  • 9
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值