controller.py
"""
好,那再给你看看controller.py 就是安装serve库的代码
A controller manages distributed workers.
It sends worker addresses to clients.
"""
import argparse
import asyncio
import dataclasses
from enum import Enum, auto
import json
import logging
import os
import time
from typing import List, Union
import threading
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import numpy as np
import requests
import uvicorn
from fastchat.constants import (
CONTROLLER_HEART_BEAT_EXPIRATION,
WORKER_API_TIMEOUT,
ErrorCode,
SERVER_ERROR_MSG,
)
from fastchat.utils import build_logger
# Module-level logger shared by everything in the controller. build_logger is
# given the name "controller" and "controller.log" (presumably the log file
# name -- confirm against fastchat.utils.build_logger).
logger = build_logger("controller", "controller.log")
class DispatchMethod(Enum):
    """Strategies the controller can use to pick a worker for a request."""

    LOTTERY = auto()
    SHORTEST_QUEUE = auto()

    @classmethod
    def from_str(cls, name):
        """Return the enum member matching a command-line style name.

        Args:
            name: Either ``"lottery"`` or ``"shortest_queue"``.

        Returns:
            The corresponding ``DispatchMethod`` member.

        Raises:
            ValueError: If ``name`` is not one of the supported names. The
                message includes the rejected value.
        """
        if name == "lottery":
            return cls.LOTTERY
        if name == "shortest_queue":
            return cls.SHORTEST_QUEUE
        raise ValueError(f"Invalid dispatch method: {name}")
@dataclasses.dataclass
class WorkerInfo:
    """Bookkeeping record the controller keeps for one registered worker."""

    # Names of the models the worker reported in its status.
    model_names: List[str]
    # Value the worker reported as "speed"; used as the selection weight by
    # lottery dispatch and as the divisor of queue_length by shortest-queue.
    speed: int
    # Last reported queue length (also bumped locally by shortest-queue dispatch).
    queue_length: int
    # Whether the worker may be removed when its heart beat expires.
    check_heart_beat: bool
    # time.time() timestamp of registration or of the last heart beat; it is
    # compared numerically against an expiry time, so it is a float, not a str.
    last_heart_beat: float
def heart_beat_controller(controller):
    """Thread target: periodically evict workers whose heart beat expired.

    Loops forever. Each round sleeps for CONTROLLER_HEART_BEAT_EXPIRATION
    seconds and then asks ``controller`` to remove stale workers.

    Args:
        controller: Object exposing ``remove_stale_workers_by_expiration()``.
    """
    while True:
        # Sleep first, so the first sweep happens one full period after start.
        time.sleep(CONTROLLER_HEART_BEAT_EXPIRATION)
        controller.remove_stale_workers_by_expiration()
class Controller:
    """Registry of model workers plus the policy for picking one per request.

    Workers live in ``worker_info`` keyed by worker name; the name doubles as
    the base URL used for HTTP calls to the worker. A background thread
    periodically drops workers whose heart beat has expired.
    """

    def __init__(self, dispatch_method: str):
        """Create an empty registry and start the heart-beat expiry thread.

        Args:
            dispatch_method: ``"lottery"`` or ``"shortest_queue"``.

        Raises:
            ValueError: If ``dispatch_method`` is not a known name.
        """
        # Dict[str -> WorkerInfo]
        self.worker_info = {}
        self.dispatch_method = DispatchMethod.from_str(dispatch_method)

        # NOTE(review): the thread is non-daemon and its target loops forever,
        # so it keeps the process alive after the server stops -- confirm
        # whether that is intended.
        self.heart_beat_thread = threading.Thread(
            target=heart_beat_controller, args=(self,)
        )
        self.heart_beat_thread.start()

    def register_worker(
        self, worker_name: str, check_heart_beat: bool, worker_status: dict
    ):
        """Add or overwrite the record for ``worker_name``.

        Args:
            worker_name: Worker identifier; also the base URL used to query it.
            check_heart_beat: Whether the worker is subject to heart-beat expiry.
            worker_status: Dict with ``model_names``, ``speed`` and
                ``queue_length``. When falsy, the status is fetched from the
                worker itself.

        Returns:
            True on success, False when no status could be obtained.
        """
        if worker_name not in self.worker_info:
            logger.info(f"Register a new worker: {worker_name}")
        else:
            logger.info(f"Register an existing worker: {worker_name}")

        if not worker_status:
            worker_status = self.get_worker_status(worker_name)
        if not worker_status:
            return False

        self.worker_info[worker_name] = WorkerInfo(
            worker_status["model_names"],
            worker_status["speed"],
            worker_status["queue_length"],
            check_heart_beat,
            time.time(),
        )

        logger.info(f"Register done: {worker_name}, {worker_status}")
        return True

    def get_worker_status(self, worker_name: str):
        """POST to the worker's ``/worker_get_status`` and return the JSON body.

        Returns:
            The decoded JSON, or None when the request fails or the response
            status is not 200 (both cases are logged).
        """
        try:
            r = requests.post(worker_name + "/worker_get_status", timeout=5)
        except requests.exceptions.RequestException as e:
            logger.error(f"Get status fails: {worker_name}, {e}")
            return None

        if r.status_code != 200:
            logger.error(f"Get status fails: {worker_name}, {r}")
            return None

        return r.json()

    def remove_worker(self, worker_name: str):
        """Forget ``worker_name``; a worker that is already gone is ignored.

        The registry is touched both by request handlers and by the heart-beat
        thread, so the entry may have vanished between the caller's check and
        this call. Raising KeyError here would terminate the heart-beat thread.
        """
        self.worker_info.pop(worker_name, None)

    def refresh_all_workers(self):
        """Re-register every known worker from its live status.

        The registry is emptied first; workers whose status cannot be fetched
        are simply not added back.
        """
        old_info = dict(self.worker_info)
        self.worker_info = {}

        for w_name, w_info in old_info.items():
            if not self.register_worker(w_name, w_info.check_heart_beat, None):
                logger.info(f"Remove stale worker: {w_name}")

    def list_models(self):
        """Return the distinct model names served by registered workers."""
        model_names = set()

        for w_info in self.worker_info.values():
            model_names.update(w_info.model_names)

        return list(model_names)

    def get_worker_address(self, model_name: str):
        """Pick a worker that serves ``model_name``.

        Lottery dispatch draws a worker at random, weighted by ``speed``.
        Shortest-queue dispatch picks the smallest ``queue_length / speed``
        and increments that worker's local queue length.

        Returns:
            The chosen worker name, or ``""`` when no worker is eligible.

        Raises:
            ValueError: If the configured dispatch method is unknown.
        """
        if self.dispatch_method == DispatchMethod.LOTTERY:
            worker_names = []
            worker_speeds = []
            for w_name, w_info in self.worker_info.items():
                if model_name in w_info.model_names:
                    worker_names.append(w_name)
                    worker_speeds.append(w_info.speed)
            worker_speeds = np.array(worker_speeds, dtype=np.float32)
            norm = np.sum(worker_speeds)
            # No candidates, or every candidate has (near) zero weight.
            if norm < 1e-4:
                return ""
            worker_speeds = worker_speeds / norm
            # The address is returned directly, without probing the worker.
            pt = np.random.choice(np.arange(len(worker_names)), p=worker_speeds)
            return worker_names[pt]
        elif self.dispatch_method == DispatchMethod.SHORTEST_QUEUE:
            worker_names = []
            worker_qlen = []
            for w_name, w_info in self.worker_info.items():
                if model_name not in w_info.model_names:
                    continue
                # A zero speed would divide by zero below; lottery dispatch
                # never selects such a worker either, so leave it out.
                if not w_info.speed:
                    continue
                worker_names.append(w_name)
                worker_qlen.append(w_info.queue_length / w_info.speed)
            if len(worker_names) == 0:
                return ""
            min_index = np.argmin(worker_qlen)
            w_name = worker_names[min_index]
            # Count the request now so dispatches made before the next heart
            # beat see the extra load.
            self.worker_info[w_name].queue_length += 1
            logger.info(
                f"names: {worker_names}, queue_lens: {worker_qlen}, ret: {w_name}"
            )
            return w_name
        else:
            raise ValueError(f"Invalid dispatch method: {self.dispatch_method}")

    def receive_heart_beat(self, worker_name: str, queue_length: int):
        """Record a heart beat and the queue length reported with it.

        Returns:
            True if the worker is registered, False otherwise (the caller
            reports this back to the worker as ``exist``).
        """
        if worker_name not in self.worker_info:
            logger.info(f"Receive unknown heart beat. {worker_name}")
            return False

        self.worker_info[worker_name].queue_length = queue_length
        self.worker_info[worker_name].last_heart_beat = time.time()
        logger.info(f"Receive heart beat. {worker_name}")
        return True

    def remove_stale_workers_by_expiration(self):
        """Remove heart-beat-checked workers that have been silent too long."""
        expire = time.time() - CONTROLLER_HEART_BEAT_EXPIRATION
        # Iterate over a snapshot: this runs on the heart-beat thread while
        # request handlers may add or remove workers.
        to_delete = [
            worker_name
            for worker_name, w_info in list(self.worker_info.items())
            if w_info.check_heart_beat and w_info.last_heart_beat < expire
        ]

        for worker_name in to_delete:
            self.remove_worker(worker_name)

    def handle_no_worker(self, params):
        """Build the NUL-terminated error chunk for "no worker available"."""
        logger.info(f"no worker: {params['model']}")
        ret = {
            "text": SERVER_ERROR_MSG,
            "error_code": ErrorCode.CONTROLLER_NO_WORKER,
        }
        return json.dumps(ret).encode() + b"\0"

    def handle_worker_timeout(self, worker_address):
        """Build the NUL-terminated error chunk for a failed worker request."""
        logger.info(f"worker timeout: {worker_address}")
        ret = {
            "text": SERVER_ERROR_MSG,
            "error_code": ErrorCode.CONTROLLER_WORKER_TIMEOUT,
        }
        return json.dumps(ret).encode() + b"\0"

    # Let the controller act as a worker to achieve hierarchical
    # management. This can be used to connect isolated sub networks.
    def worker_api_get_status(self):
        """Aggregate the live status of all registered workers.

        Returns:
            Dict with the sorted union of model names and the summed
            ``speed`` and ``queue_length`` of every worker that answered.
        """
        model_names = set()
        speed = 0
        queue_length = 0

        for w_name in self.worker_info:
            worker_status = self.get_worker_status(w_name)
            if worker_status is not None:
                model_names.update(worker_status["model_names"])
                speed += worker_status["speed"]
                queue_length += worker_status["queue_length"]

        return {
            "model_names": sorted(model_names),
            "speed": speed,
            "queue_length": queue_length,
        }

    def worker_api_generate_stream(self, params):
        """Forward a generation request to a worker and relay its stream.

        Args:
            params: Request dict; ``params["model"]`` selects the worker and
                the whole dict is forwarded as the JSON body.

        Yields:
            NUL-terminated byte chunks from the worker, or a single error
            chunk when no worker is available or the request fails.
        """
        worker_addr = self.get_worker_address(params["model"])
        if not worker_addr:
            yield self.handle_no_worker(params)
            # Stop here: there is no address to forward to. Falling through
            # would POST to "/worker_generate_stream" with no host and emit a
            # second, misleading timeout chunk.
            return

        try:
            response = requests.post(
                worker_addr + "/worker_generate_stream",
                json=params,
                stream=True,
                timeout=WORKER_API_TIMEOUT,
            )
            for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
                if chunk:
                    yield chunk + b"\0"
        except requests.exceptions.RequestException:
            yield self.handle_worker_timeout(worker_addr)
# FastAPI application object; the route handlers below are registered on it.
app = FastAPI()
@app.post("/register_worker")
async def register_worker(request: Request):
    """Register (or re-register) a worker described by the JSON body.

    The body must contain ``worker_name`` and ``check_heart_beat``;
    ``worker_status`` is optional. The registration result is not returned.
    """
    payload = await request.json()
    worker_name = payload["worker_name"]
    check_heart_beat = payload["check_heart_beat"]
    worker_status = payload.get("worker_status", None)
    controller.register_worker(worker_name, check_heart_beat, worker_status)
@app.post("/refresh_all_workers")
async def refresh_all_workers():
    """Ask the controller to re-register every known worker.

    ``Controller.refresh_all_workers`` returns nothing, so there is no result
    to keep or send back.
    """
    controller.refresh_all_workers()
@app.post("/list_models")
async def list_models():
    """Return the model names known to the controller as ``{"models": [...]}``."""
    return {"models": controller.list_models()}
@app.post("/get_worker_address")
async def get_worker_address(request: Request):
    """Return ``{"address": ...}`` for the model named in the JSON body.

    The address is whatever ``controller.get_worker_address`` returns for
    ``body["model"]``.
    """
    body = await request.json()
    return {"address": controller.get_worker_address(body["model"])}
@app.post("/receive_heart_beat")
async def receive_heart_beat(request: Request):
    """Record a worker heart beat and report whether the worker is known.

    Reads ``worker_name`` and ``queue_length`` from the JSON body and returns
    ``{"exist": <result of controller.receive_heart_beat>}``.
    """
    body = await request.json()
    known = controller.receive_heart_beat(body["worker_name"], body["queue_length"])
    return {"exist": known}
@app.post("/worker_generate_stream")
async def worker_api_generate_stream(request: Request):
    """Stream back the controller's relay of a generation request.

    The JSON body is passed unchanged to
    ``controller.worker_api_generate_stream``.
    """
    body = await request.json()
    return StreamingResponse(controller.worker_api_generate_stream(body))
@app.post("/worker_get_status")
async def worker_api_get_status(request: Request):
    """Return the status dict built by ``controller.worker_api_get_status``."""
    status = controller.worker_api_get_status()
    return status
@app.get("/test_connection")
async def test_connection(request: Request):
    """Liveness probe: always answers with the string ``"success"``.

    The handler has its own name so it no longer rebinds the module-level
    name of the ``/worker_get_status`` handler; the route is unchanged.
    """
    return "success"
def create_controller():
    """Parse command-line options and build the controller.

    Returns:
        Tuple ``(args, controller)``: the parsed argparse namespace and a
        ``Controller`` using the selected dispatch method.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="localhost")
    cli.add_argument("--port", type=int, default=21001)
    cli.add_argument(
        "--dispatch-method",
        type=str,
        choices=["lottery", "shortest_queue"],
        default="shortest_queue",
    )
    cli.add_argument(
        "--ssl",
        action="store_true",
        required=False,
        default=False,
        help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
    )
    args = cli.parse_args()
    logger.info(f"args: {args}")

    return args, Controller(args.dispatch_method)
if __name__ == "__main__":
    # `controller` has to stay a module-level name: the route handlers above
    # look it up as a global.
    args, controller = create_controller()

    run_options = {"host": args.host, "port": args.port, "log_level": "info"}
    if args.ssl:
        # Both variables are required with --ssl; a missing one raises KeyError.
        run_options["ssl_keyfile"] = os.environ["SSL_KEYFILE"]
        run_options["ssl_certfile"] = os.environ["SSL_CERTFILE"]
    uvicorn.run(app, **run_options)
实现
根据提供的代码,可以看出这是一个使用FastAPI构建的控制器(Controller)。控制器负责管理分布式的工作节点(Workers),并向客户端发送工作节点的地址。
具体实现如下:
1. 导入所需的库和模块。
2. 定义日志记录器、调度方式枚举(`DispatchMethod`)、工作节点信息数据类(`WorkerInfo`)、心跳检查函数(`heart_beat_controller`)以及控制器(Controller)类。
3. 在控制器(Controller)类的`__init__`方法中,初始化工作节点信息字典(`worker_info`)和调度方法(`dispatch_method`),并启动后台心跳检查线程(`heart_beat_thread`)。
4. 实现`register_worker`方法,用于注册工作节点。
5. 实现`get_worker_status`方法,用于获取工作节点的状态。
6. 实现`remove_worker`方法,用于移除工作节点。
7. 实现`refresh_all_workers`方法,用于刷新所有工作节点。
8. 实现`list_models`方法,用于列出所有可用的模型。
9. 实现`get_worker_address`方法,根据调度方法选择一个工作节点,并返回其地址。
10. 实现`receive_heart_beat`方法,用于接收工作节点的心跳信息。
11. 实现`remove_stale_workers_by_expiration`方法,根据心跳信息的过期时间移除失效的工作节点。
12. 实现`handle_no_worker`方法和`handle_worker_timeout`方法,用于处理没有可用工作节点和工作节点超时的情况。
13. 实现`worker_api_get_status`方法和`worker_api_generate_stream`方法,使控制器自身也能充当工作节点(对外提供与工作节点相同的API接口),从而实现分层管理,可用于连接相互隔离的子网络。
14. 创建FastAPI应用实例。
15. 定义各个API的路由和处理函数。
16. 创建控制器(Controller)实例。
17. 根据命令行参数启动FastAPI应用。
以上是根据提供的代码推测出的使用serve库构建控制器的具体实现方式。如果需要更详细的实现细节,建议参考FastChat框架的文档和源代码。
model_worker.py
"""
以下是model_worker.py 也就是启动本地LLM的具体实现的FastChat框架的内部实现细节
A model worker that executes the model.
"""
import argparse
import base64
import gc
import json
import os
from typing import List, Optional
import uuid
import torch
import torch.nn.functional as F
from transformers import set_seed
import uvicorn
from fastchat.constants import ErrorCode, SERVER_ERROR_MSG
from fastchat.model.model_adapter import (
load_model,
add_model_args,
get_generate_stream_function,
)
from fastchat.modules.awq import AWQConfig
from fastchat.modules.exllama import ExllamaConfig
from fastchat.modules.gptq import GptqConfig
from fastchat.serve.base_model_worker import BaseModelWorker, app
from fastchat.utils import (
build_logger,
get_context_length,
str_to_torch_dtype,
)
# Short random identifier for this worker process: the first 8 characters of
# a UUID4 string.
worker_id = str(uuid.uuid4())[:8]
# Logger named "model_worker"; the second argument embeds worker_id
# (presumably the per-worker log file name -- confirm against build_logger).
logger = build_logger("model_worker", f"model_worker_{worker_id}.log")
class ModelWorker(BaseModelWorker):
    """Worker that loads one model and serves generation and embeddings."""

    def __init__(
        self,
        controller_addr: str,
        worker_addr: str,
        worker_id: str,
        model_path: str,
        model_names: List[str],
        limit_worker_concurrency: int,
        no_register: bool,
        device: str,
        num_gpus: int,
        max_gpu_memory: str,
        dtype: Optional[torch.dtype] = None,
        load_8bit: bool = False,
        cpu_offloading: bool = False,
        gptq_config: Optional[GptqConfig] = None,
        awq_config: Optional[AWQConfig] = None,
        exllama_config: Optional[ExllamaConfig] = None,
        stream_interval: int = 2,
        conv_template: Optional[str] = None,
        embed_in_truncate: bool = False,
        seed: Optional[int] = None,
        **kwargs,
    ):
        """Load the model and tokenizer and optionally start heart beats.

        The first six arguments and ``conv_template`` are handed to
        ``BaseModelWorker``; the device and quantization options are handed
        to ``load_model``. Unless ``no_register`` is true,
        ``init_heart_beat()`` is called at the end. Extra keyword arguments
        are accepted and ignored.
        """
        super().__init__(
            controller_addr,
            worker_addr,
            worker_id,
            model_path,
            model_names,
            limit_worker_concurrency,
            conv_template=conv_template,
        )

        logger.info(f"Loading the model {self.model_names} on worker {worker_id} ...")
        self.model, self.tokenizer = load_model(
            model_path,
            device=device,
            num_gpus=num_gpus,
            max_gpu_memory=max_gpu_memory,
            dtype=dtype,
            load_8bit=load_8bit,
            cpu_offloading=cpu_offloading,
            gptq_config=gptq_config,
            awq_config=awq_config,
            exllama_config=exllama_config,
        )
        self.device = device
        # get_embeddings pads batches, so a tokenizer without a pad token
        # falls back to its EOS token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.context_len = get_context_length(self.model.config)
        self.generate_stream_func = get_generate_stream_function(self.model, model_path)
        self.stream_interval = stream_interval
        self.embed_in_truncate = embed_in_truncate
        self.seed = seed

        if not no_register:
            self.init_heart_beat()

    def generate_stream_gate(self, params):
        """Run streaming generation and yield NUL-terminated JSON chunks.

        Each chunk carries ``text`` and ``error_code`` (0 on success) plus
        ``usage``, ``finish_reason`` and ``logprobs`` when the generator
        output contains them. CUDA OOM, ValueError and RuntimeError are
        turned into a single error chunk instead of propagating.
        """
        self.call_ct += 1

        try:
            if self.seed is not None:
                set_seed(self.seed)
            for output in self.generate_stream_func(
                self.model,
                self.tokenizer,
                params,
                self.device,
                self.context_len,
                self.stream_interval,
            ):
                ret = {
                    "text": output["text"],
                    "error_code": 0,
                }
                # Optional fields are copied through only when present.
                for key in ("usage", "finish_reason", "logprobs"):
                    if key in output:
                        ret[key] = output[key]
                yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.OutOfMemoryError as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
            }
            yield json.dumps(ret).encode() + b"\0"
        except (ValueError, RuntimeError) as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
            yield json.dumps(ret).encode() + b"\0"

    def generate_gate(self, params):
        """Run generation to completion and return the final chunk as a dict.

        Returns:
            The decoded last chunk of ``generate_stream_gate``. If the stream
            produces no chunk at all, an INTERNAL_ERROR dict is returned
            instead of failing on an unbound variable.
        """
        last_chunk = None
        for last_chunk in self.generate_stream_gate(params):
            pass
        if last_chunk is None:
            return {
                "text": SERVER_ERROR_MSG,
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
        # Drop the trailing NUL delimiter before decoding.
        return json.loads(last_chunk[:-1].decode())

    def __process_embed_chunk(self, input_ids, attention_mask, **model_type_dict):
        """Return the masked sum of hidden states and the token count.

        The hidden states come from a different model output depending on the
        ``is_bert`` / ``is_robert`` / ``is_t5`` / ``is_chatglm`` flags. They
        are multiplied by the attention mask and summed over dimension 1.

        Returns:
            Tuple ``(sum_embeddings, token_num)`` where ``token_num`` is the
            number of true entries in the whole ``attention_mask``.
        """
        if model_type_dict.get("is_bert"):
            model_output = self.model(input_ids)
            if model_type_dict.get("is_robert"):
                data = model_output.last_hidden_state
            else:
                data = model_output[0]
        elif model_type_dict.get("is_t5"):
            model_output = self.model(input_ids, decoder_input_ids=input_ids)
            data = model_output.encoder_last_hidden_state
        else:
            model_output = self.model(input_ids, output_hidden_states=True)
            if model_type_dict.get("is_chatglm"):
                data = model_output.hidden_states[-1].transpose(0, 1)
            else:
                data = model_output.hidden_states[-1]
        mask = attention_mask.unsqueeze(-1).expand(data.size()).float()
        masked_embeddings = data * mask
        sum_embeddings = torch.sum(masked_embeddings, dim=1)
        token_num = torch.sum(attention_mask).item()

        return sum_embeddings, token_num

    def __encode_base64(self, embeddings: torch.Tensor) -> List[str]:
        """Encode each row's raw bytes as a base64 UTF-8 string."""
        embeddings = embeddings.cpu()
        return [
            base64.b64encode(e.numpy().tobytes()).decode("utf-8") for e in embeddings
        ]

    @torch.inference_mode()
    def get_embeddings(self, params):
        """Compute L2-normalized embeddings for ``params["input"]``.

        With ``embed_in_truncate`` the input is truncated to the context
        length and embedded in one pass; otherwise it is processed in
        context-length slices whose sums are combined.

        Args:
            params: Dict with ``input`` and an optional ``encoding_format``;
                ``"base64"`` selects base64 output, anything else a nested list.

        Returns:
            ``{"embedding": ..., "token_num": ...}`` on success, or a dict
            with ``text`` and ``error_code`` on CUDA OOM, ValueError or
            RuntimeError.
        """
        self.call_ct += 1

        try:
            tokenizer = self.tokenizer
            ret = {"embedding": [], "token_num": 0}

            # The model family is detected from the class name of the model.
            model_type = str(type(self.model))
            model_type_dict = {
                "is_llama": "llama" in model_type,
                "is_t5": "t5" in model_type,
                "is_chatglm": "chatglm" in model_type,
                "is_bert": "bert" in model_type,
                "is_robert": "robert" in model_type,
            }

            if self.embed_in_truncate:
                encoding = tokenizer.batch_encode_plus(
                    params["input"],
                    padding=True,
                    truncation="longest_first",
                    return_tensors="pt",
                    max_length=self.context_len,
                )
            else:
                encoding = tokenizer.batch_encode_plus(
                    params["input"], padding=True, return_tensors="pt"
                )
            input_ids = encoding["input_ids"].to(self.device)
            # NOTE(review): the mask is derived from pad_token_id rather than
            # taken from the tokenizer output; when pad_token was set to the
            # EOS token, genuine EOS tokens are masked too -- confirm intended.
            attention_mask = input_ids != tokenizer.pad_token_id

            base64_encode = params.get("encoding_format", None)

            if self.embed_in_truncate:
                chunk_embeddings, token_num = self.__process_embed_chunk(
                    input_ids, attention_mask, **model_type_dict
                )
                embedding = chunk_embeddings / token_num
                normalized_embeddings = F.normalize(embedding, p=2, dim=1)
                ret["token_num"] = token_num
            else:
                all_embeddings = []
                all_token_num = 0
                for i in range(0, input_ids.size(1), self.context_len):
                    chunk_input_ids = input_ids[:, i : i + self.context_len]
                    chunk_attention_mask = attention_mask[:, i : i + self.context_len]

                    chunk_embeddings, token_num = self.__process_embed_chunk(
                        chunk_input_ids, chunk_attention_mask, **model_type_dict
                    )
                    all_embeddings.append(chunk_embeddings)
                    all_token_num += token_num

                all_embeddings_tensor = torch.stack(all_embeddings)
                embedding = torch.sum(all_embeddings_tensor, dim=0) / all_token_num
                normalized_embeddings = F.normalize(embedding, p=2, dim=1)

                ret["token_num"] = all_token_num

            if base64_encode == "base64":
                out_embeddings = self.__encode_base64(normalized_embeddings)
            else:
                out_embeddings = normalized_embeddings.tolist()
            ret["embedding"] = out_embeddings

            gc.collect()
            torch.cuda.empty_cache()
            if self.device == "xpu":
                torch.xpu.empty_cache()
            if self.device == "npu":
                torch.npu.empty_cache()
        except torch.cuda.OutOfMemoryError as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
            }
        except (ValueError, RuntimeError) as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
        return ret
def create_model_worker():
    """Parse command-line options and construct the ``ModelWorker``.

    When ``--gpus`` is given it is exported as ``CUDA_VISIBLE_DEVICES``.

    Returns:
        Tuple ``(args, worker)``.

    Raises:
        ValueError: If ``--gpus`` lists fewer devices than ``--num-gpus``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--host", type=str, default="localhost")
    cli.add_argument("--port", type=int, default=21002)
    cli.add_argument("--worker-address", type=str, default="http://localhost:21002")
    cli.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    # Model-related options are contributed by the model adapter module.
    add_model_args(cli)
    cli.add_argument(
        "--model-names",
        type=lambda s: s.split(","),
        help="Optional display comma separated names",
    )
    cli.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    cli.add_argument("--embed-in-truncate", action="store_true")
    cli.add_argument(
        "--limit-worker-concurrency",
        type=int,
        default=5,
        help="Limit the model concurrency to prevent OOM.",
    )
    cli.add_argument("--stream-interval", type=int, default=2)
    cli.add_argument("--no-register", action="store_true")
    cli.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Overwrite the random seed for each generation.",
    )
    args = cli.parse_args()
    logger.info(f"args: {args}")

    if args.gpus:
        listed_gpus = args.gpus.split(",")
        if len(listed_gpus) < args.num_gpus:
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    # Each quantization checkpoint defaults to the model path itself.
    gptq_config = GptqConfig(
        ckpt=args.gptq_ckpt or args.model_path,
        wbits=args.gptq_wbits,
        groupsize=args.gptq_groupsize,
        act_order=args.gptq_act_order,
    )
    awq_config = AWQConfig(
        ckpt=args.awq_ckpt or args.model_path,
        wbits=args.awq_wbits,
        groupsize=args.awq_groupsize,
    )
    exllama_config = (
        ExllamaConfig(
            max_seq_len=args.exllama_max_seq_len,
            gpu_split=args.exllama_gpu_split,
        )
        if args.enable_exllama
        else None
    )

    worker = ModelWorker(
        args.controller_address,
        args.worker_address,
        worker_id,
        args.model_path,
        args.model_names,
        args.limit_worker_concurrency,
        no_register=args.no_register,
        device=args.device,
        num_gpus=args.num_gpus,
        max_gpu_memory=args.max_gpu_memory,
        dtype=str_to_torch_dtype(args.dtype),
        load_8bit=args.load_8bit,
        cpu_offloading=args.cpu_offloading,
        gptq_config=gptq_config,
        awq_config=awq_config,
        exllama_config=exllama_config,
        stream_interval=args.stream_interval,
        conv_template=args.conv_template,
        embed_in_truncate=args.embed_in_truncate,
        seed=args.seed,
    )
    return args, worker
if __name__ == "__main__":
    # `worker` stays a module-level name, exactly as in the original script.
    args, worker = create_model_worker()
    # Serve the FastAPI app imported from base_model_worker.
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
代码分析
根据提供的代码,可以看出启动本地LLM的具体实现是通过创建一个`ModelWorker`类的实例,并在`__init__`方法中加载模型和初始化相关参数。具体步骤如下:
1. 导入所需的库和模块。
2. 定义日志记录器和工作器ID。
3. 创建`ModelWorker`类,继承自`BaseModelWorker`类。
4. 在`ModelWorker`类的`__init__`方法中,加载模型和初始化相关参数,包括模型路径、模型名称、设备类型、GPU数量等。
5. 实现`generate_stream_gate`方法和`generate_gate`方法,用于生成LLM的输出。
6. 实现`__process_embed_chunk`方法和`get_embeddings`方法,用于获取LLM的嵌入向量。
7. 创建`create_model_worker`函数,用于创建`ModelWorker`实例。
8. 在`if __name__ == "__main__":`入口代码块中,调用`create_model_worker`函数创建`ModelWorker`实例,并使用`uvicorn`库运行FastAPI应用。
以上是根据提供的代码推测出的启动本地LLM的具体实现方式。如果需要更详细的实现细节,建议参考FastChat框架的文档和源代码。
openai_api_server.py
"""A server that provides OpenAI-compatible RESTful APIs. It supports:
还有openai_api_server.py 是启动API:
返回的key是用来与Autogen对接的
- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat)
- Completions. (Reference: https://platform.openai.com/docs/api-reference/completions)
- Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings)
Usage:
python3 -m fastchat.serve.openai_api_server
"""
import asyncio
import argparse
import json
import logging
import os
from typing import Generator, Optional, Union, Dict, List, Any
import aiohttp
import fastapi
from fastapi import Depends, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
import httpx
from pydantic import BaseSettings
import shortuuid
import tiktoken
import uvicorn
from fastchat.constants import (
WORKER_API_TIMEOUT,
WORKER_API_EMBEDDING_BATCH_SIZE,
ErrorCode,
)
from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.protocol.openai_api_protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse,
ChatMessage,
ChatCompletionResponseChoice,
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
DeltaMessage,
CompletionResponseStreamChoice,
CompletionStreamResponse,
EmbeddingsRequest,
EmbeddingsResponse,
ErrorResponse,
ModelCard,
ModelList,
ModelPermission,
UsageInfo,
)
from fastchat.protocol.api_protocol import (
APIChatCompletionRequest,
APITokenCheckRequest,
APITokenCheckResponse,
APITokenCheckResponseItem,
)
logger = logging.getLogger(__name__)
# Cache of conversation templates keyed by (worker_addr, model_name); it is
# filled lazily by get_conv.
conv_template_map = {}
# Total timeout for the requests made by fetch_remote: 3 hours.
fetch_timeout = aiohttp.ClientTimeout(total=3 * 3600)
async def fetch_remote(url, pload=None, name=None):
    """POST ``pload`` as JSON to ``url`` and return the response body.

    Args:
        url: Target URL.
        pload: JSON payload, or None.
        name: Controls decoding. None returns the raw bytes; ``""`` returns
            the parsed JSON document; any other value returns that key of
            the parsed JSON document.
    """
    async with aiohttp.ClientSession(timeout=fetch_timeout) as session:
        async with session.post(url, json=pload) as response:
            pieces = [piece async for piece, _ in response.content.iter_chunks()]
    body = b"".join(pieces)

    if name is None:
        return body
    parsed = json.loads(body)
    return parsed if name == "" else parsed[name]
class AppSettings(BaseSettings):
    # Settings of the API server, declared as pydantic BaseSettings fields.

    # The address of the model controller.
    controller_address: str = "http://localhost:21001"
    # Accepted bearer tokens. While this is unset or empty, check_api_key
    # lets every request through.
    api_keys: Optional[List[str]] = None
# Process-wide settings instance read by the handlers and helpers below.
app_settings = AppSettings()
app = fastapi.FastAPI()
headers = {"User-Agent": "FastChat API Server"}
# auto_error=False: check_api_key handles the `auth is None` case itself
# (presumably a missing Authorization header -- see the FastAPI HTTPBearer docs).
get_bearer_token = HTTPBearer(auto_error=False)
async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
) -> Optional[str]:
    """FastAPI dependency that enforces the configured API keys.

    Returns:
        The presented token when it is one of ``app_settings.api_keys``, or
        None when no keys are configured (every request is allowed).

    Raises:
        HTTPException: 401 with an ``invalid_api_key`` error body when keys
            are configured and the token is missing or not accepted.
    """
    if not app_settings.api_keys:
        # api_keys not set; allow all
        return None

    if auth is None or (token := auth.credentials) not in app_settings.api_keys:
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    "message": "",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key",
                }
            },
        )
    return token
def create_error_response(code: int, message: str) -> JSONResponse:
    """Wrap an error code and message in an HTTP 400 JSON response."""
    payload = ErrorResponse(message=message, code=code).dict()
    return JSONResponse(payload, status_code=400)
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
    """Report request validation failures as a VALIDATION_TYPE_ERROR response."""
    detail = str(exc)
    return create_error_response(ErrorCode.VALIDATION_TYPE_ERROR, detail)
async def check_model(request) -> Optional[JSONResponse]:
    """Verify that ``request.model`` is served by some worker.

    Returns:
        None when the model is in the controller's model list, otherwise an
        INVALID_MODEL error response listing the allowed models.
    """
    list_url = app_settings.controller_address + "/list_models"
    models = await fetch_remote(list_url, None, "models")
    if request.model in models:
        return None
    return create_error_response(
        ErrorCode.INVALID_MODEL,
        f"Only {'&&'.join(models)} allowed now, your model {request.model}",
    )
async def check_length(request, prompt, max_tokens, worker_addr):
    """Cap ``max_tokens`` by the room left in the model's context window.

    Asks the worker for the model's context length and for the token count
    of ``prompt``.

    Returns:
        ``min(max_tokens, context_length - prompt_tokens)``, where a missing
        or non-positive ``max_tokens`` is first replaced by ``1024 * 1024``.
    """
    # model worker not support max_tokens=None
    if not (isinstance(max_tokens, int) and max_tokens > 0):
        max_tokens = 1024 * 1024

    context_len = await fetch_remote(
        worker_addr + "/model_details", {"model": request.model}, "context_length"
    )
    token_num = await fetch_remote(
        worker_addr + "/count_token",
        {"model": request.model, "prompt": prompt},
        "count",
    )
    remaining = context_len - token_num
    return min(max_tokens, remaining)
def check_requests(request) -> Optional[JSONResponse]:
    """Validate the sampling parameters of a completion request.

    Checks ``max_tokens`` (>= 1), ``n`` (>= 1), ``temperature`` (0..2),
    ``top_p`` (0..1) and ``stop`` (str or list). Parameters that are None
    are skipped.

    Returns:
        None when every check passes, otherwise a PARAM_OUT_OF_RANGE error
        response describing the first violation found.
    """
    # Check all params
    if request.max_tokens is not None and request.max_tokens <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
        )
    if request.n is not None and request.n <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.n} is less than the minimum of 1 - 'n'",
        )
    if request.temperature is not None and request.temperature < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is less than the minimum of 0 - 'temperature'",
        )
    if request.temperature is not None and request.temperature > 2:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
        )
    if request.top_p is not None and request.top_p < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is less than the minimum of 0 - 'top_p'",
        )
    if request.top_p is not None and request.top_p > 1:
        # The message must name 'top_p', the parameter that is out of range.
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
        )
    if request.stop is not None and not isinstance(request.stop, (str, list)):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.stop} is not valid under any of the given schemas - 'stop'",
        )

    return None
def process_input(model_name, inp):
    """Normalize a completion ``prompt`` into a list of strings.

    A string becomes a one-element list. A list of ints is decoded as one
    token sequence and a list of lists as several, using the tiktoken
    encoding for ``model_name``. Any other value, including an empty list
    or a list of strings, is returned unchanged.
    """
    if isinstance(inp, str):
        return [inp]
    # The emptiness check avoids IndexError on inp[0] for an empty list.
    if isinstance(inp, list) and inp:
        first = inp[0]
        if isinstance(first, int):
            decoding = tiktoken.model.encoding_for_model(model_name)
            return [decoding.decode(inp)]
        if isinstance(first, list):
            decoding = tiktoken.model.encoding_for_model(model_name)
            return [decoding.decode(text) for text in inp]

    return inp
def _add_to_set(s, new_stop):
if not s:
return
if isinstance(s, str):
new_stop.add(s)
else:
new_stop.update(s)
async def get_gen_params(
    model_name: str,
    worker_addr: str,
    messages: Union[str, List[Dict[str, str]]],
    *,
    temperature: float,
    top_p: float,
    max_tokens: Optional[int],
    echo: Optional[bool],
    stop: Optional[Union[str, List[str]]],
) -> Dict[str, Any]:
    """Build the generation parameter dict that is sent to a worker.

    The conversation template is obtained through ``get_conv``. A string
    ``messages`` is used as the prompt verbatim. A list of role/content
    dicts is rendered through the template, followed by an empty assistant
    turn.

    Returns:
        Dict with ``model``, ``prompt``, ``temperature``, ``top_p``,
        ``max_new_tokens``, ``echo``, ``stop_token_ids`` and ``stop``.
        ``stop`` merges the caller's stop value with the template's stop
        string.

    Raises:
        ValueError: If a message has a role other than ``system``, ``user``
            or ``assistant``.
    """
    template = await get_conv(model_name, worker_addr)
    conv = Conversation(
        name=template["name"],
        system_template=template["system_template"],
        system_message=template["system_message"],
        roles=template["roles"],
        messages=list(template["messages"]),  # prevent in-place modification
        offset=template["offset"],
        sep_style=SeparatorStyle(template["sep_style"]),
        sep=template["sep"],
        sep2=template["sep2"],
        stop_str=template["stop_str"],
        stop_token_ids=template["stop_token_ids"],
    )

    if not isinstance(messages, str):
        for message in messages:
            msg_role = message["role"]
            if msg_role == "system":
                conv.set_system_message(message["content"])
                continue
            if msg_role == "user":
                conv.append_message(conv.roles[0], message["content"])
                continue
            if msg_role == "assistant":
                conv.append_message(conv.roles[1], message["content"])
                continue
            raise ValueError(f"Unknown role: {msg_role}")

        # Add a blank message for the assistant.
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()
    else:
        prompt = messages

    # Stop strings from the request and from the template are deduplicated.
    new_stop = set()
    _add_to_set(stop, new_stop)
    _add_to_set(conv.stop_str, new_stop)

    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "top_p": top_p,
        "max_new_tokens": max_tokens,
        "echo": echo,
        "stop_token_ids": conv.stop_token_ids,
        "stop": list(new_stop),
    }

    logger.debug(f"==== request ====\n{gen_params}")
    return gen_params
async def get_worker_address(model_name: str) -> str:
    """
    Get worker address based on the requested model

    :param model_name: The worker's model name
    :return: Worker address from the controller
    :raises: :class:`ValueError`: No available worker for requested model
    """
    lookup_url = app_settings.controller_address + "/get_worker_address"
    worker_addr = await fetch_remote(lookup_url, {"model": model_name}, "address")

    # The controller answers with an empty string when no worker is available.
    if worker_addr == "":
        raise ValueError(f"No available worker for {model_name}")

    logger.debug(f"model_name: {model_name}, worker_addr: {worker_addr}")
    return worker_addr
async def get_conv(model_name: str, worker_addr: str):
    """Return the conversation template for a model on a worker.

    Templates are cached in ``conv_template_map`` under
    ``(worker_addr, model_name)``. On a miss the template is fetched from
    the worker's ``/worker_get_conv_template`` endpoint and stored.
    """
    cache_key = (worker_addr, model_name)
    cached = conv_template_map.get(cache_key)
    if cached is not None:
        return cached

    fetched = await fetch_remote(
        worker_addr + "/worker_get_conv_template", {"model": model_name}, "conv"
    )
    conv_template_map[cache_key] = fetched
    return fetched
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
async def show_available_models():
    """List the models currently served, sorted by name.

    The controller is first asked to refresh its workers; the response body
    of that call is not used.
    """
    controller_address = app_settings.controller_address
    await fetch_remote(controller_address + "/refresh_all_workers")
    models = await fetch_remote(controller_address + "/list_models", None, "models")
    models.sort()

    # TODO: return real model permission details
    model_cards = [
        ModelCard(id=m, root=m, permission=[ModelPermission()]) for m in models
    ]
    return ModelList(data=model_cards)
@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key)])
async def create_chat_completion(request: ChatCompletionRequest):
    """Creates a completion for the chat message"""
    model_error = await check_model(request)
    if model_error is not None:
        return model_error
    param_error = check_requests(request)
    if param_error is not None:
        return param_error

    worker_addr = await get_worker_address(request.model)

    gen_params = await get_gen_params(
        request.model,
        worker_addr,
        request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens,
        echo=False,
        stop=request.stop,
    )
    # Clamp the requested length to what still fits in the context window.
    gen_params["max_new_tokens"] = await check_length(
        request,
        gen_params["prompt"],
        gen_params["max_new_tokens"],
        worker_addr,
    )

    if request.stream:
        generator = chat_completion_stream_generator(
            request.model, gen_params, request.n, worker_addr
        )
        return StreamingResponse(generator, media_type="text/event-stream")

    # One task per requested choice; they run concurrently.
    pending = [
        asyncio.create_task(generate_completion(gen_params, worker_addr))
        for _ in range(request.n)
    ]
    try:
        results = await asyncio.gather(*pending)
    except Exception as e:
        return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))

    choices = []
    usage = UsageInfo()
    for index, content in enumerate(results):
        if content["error_code"] != 0:
            return create_error_response(content["error_code"], content["text"])
        choices.append(
            ChatCompletionResponseChoice(
                index=index,
                message=ChatMessage(role="assistant", content=content["text"]),
                finish_reason=content.get("finish_reason", "stop"),
            )
        )
        if "usage" in content:
            # Add this choice's usage counters to the running totals.
            task_usage = UsageInfo.parse_obj(content["usage"])
            for usage_key, usage_value in task_usage.dict().items():
                setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
async def chat_completion_stream_generator(
    model_name: str, gen_params: Dict[str, Any], n: int, worker_addr: str
) -> Generator[str, Any, None]:
    """
    Yield the chat completion as server-sent events, one choice after another.

    For each of the ``n`` choices a first event carries only the assistant
    role; later events carry the text added since the previous chunk. Events
    that carry only a finish reason are held back and sent after all choices,
    followed by ``[DONE]``. An error chunk from the worker is forwarded and
    ends the stream.

    Event stream format:
    https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format
    """
    completion_id = f"chatcmpl-{shortuuid.random()}"
    finish_stream_events = []
    for i in range(n):
        # First chunk with role
        role_choice = ChatCompletionResponseStreamChoice(
            index=i,
            delta=DeltaMessage(role="assistant"),
            finish_reason=None,
        )
        role_chunk = ChatCompletionStreamResponse(
            id=completion_id, choices=[role_choice], model=model_name
        )
        yield f"data: {role_chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"

        previous_text = ""
        async for content in generate_completion_stream(gen_params, worker_addr):
            if content["error_code"] != 0:
                yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n"
                yield "data: [DONE]\n\n"
                return

            decoded_unicode = content["text"].replace("\ufffd", "")
            # Only the part beyond what was already sent is new; an empty
            # remainder is represented as None.
            delta_text = decoded_unicode[len(previous_text) :] or None
            if len(decoded_unicode) > len(previous_text):
                previous_text = decoded_unicode

            finish_reason = content.get("finish_reason", None)
            choice_data = ChatCompletionResponseStreamChoice(
                index=i,
                delta=DeltaMessage(content=delta_text),
                finish_reason=finish_reason,
            )
            chunk = ChatCompletionStreamResponse(
                id=completion_id, choices=[choice_data], model=model_name
            )
            if delta_text is None:
                if finish_reason is not None:
                    finish_stream_events.append(chunk)
                continue
            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"

    # There is not "content" field in the last delta message, so exclude_none to exclude field "content".
    for finish_chunk in finish_stream_events:
        yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
    yield "data: [DONE]\n\n"
@app.post("/v1/completions", dependencies=[Depends(check_api_key)])
async def create_completion(request: CompletionRequest):
    """
    Handle a text-completion request.

    Validates the model and the request, normalizes the prompt via
    ``process_input``, and lowers ``request.max_tokens`` whenever
    ``check_length`` returns a smaller integer for any prompt. Streaming
    requests get a ``text/event-stream`` response; otherwise ``request.n``
    generations per prompt run concurrently and are merged into a single
    ``CompletionResponse`` with summed usage.
    """
    if (model_error := await check_model(request)) is not None:
        return model_error
    if (request_error := check_requests(request)) is not None:
        return request_error

    request.prompt = process_input(request.model, request.prompt)
    worker_addr = await get_worker_address(request.model)

    # Shrink the token budget to the tightest limit across all prompts.
    for text in request.prompt:
        allowed = await check_length(request, text, request.max_tokens, worker_addr)
        if isinstance(allowed, int) and allowed < request.max_tokens:
            request.max_tokens = allowed

    if request.stream:
        return StreamingResponse(
            generate_completion_stream_generator(request, request.n, worker_addr),
            media_type="text/event-stream",
        )

    # Non-streaming: schedule n generations per prompt, then await them together.
    pending = []
    for text in request.prompt:
        gen_params = await get_gen_params(
            request.model,
            worker_addr,
            text,
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=request.max_tokens,
            echo=request.echo,
            stop=request.stop,
        )
        pending.extend(
            asyncio.create_task(generate_completion(gen_params, worker_addr))
            for _ in range(request.n)
        )

    try:
        results = await asyncio.gather(*pending)
    except Exception as e:
        return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))

    choices = []
    usage = UsageInfo()
    for position, result in enumerate(results):
        if result["error_code"] != 0:
            return create_error_response(result["error_code"], result["text"])
        choices.append(
            CompletionResponseChoice(
                index=position,
                text=result["text"],
                logprobs=result.get("logprobs"),
                finish_reason=result.get("finish_reason", "stop"),
            )
        )
        # Accumulate every usage counter field by field.
        for field, amount in UsageInfo.parse_obj(result["usage"]).dict().items():
            setattr(usage, field, getattr(usage, field) + amount)

    return CompletionResponse(
        model=request.model, choices=choices, usage=UsageInfo.parse_obj(usage)
    )
async def generate_completion_stream_generator(
    request: CompletionRequest, n: int, worker_addr: str
):
    """
    Stream text completions for every prompt in ``request.prompt`` as
    server-sent events.

    For each prompt, ``n`` completions are streamed one after another. Each
    non-empty text increment is emitted as its own chunk; chunks that add no
    text but carry a ``finish_reason`` are held back and emitted after all
    completions, right before the final ``[DONE]`` sentinel. A worker payload
    with a non-zero ``error_code`` is forwarded unchanged and ends the stream.

    Choice indices are numbered globally (``prompt_index * n + sample_index``)
    so that chunks belonging to different prompts can be told apart, matching
    the numbering of the non-streaming ``create_completion`` path.

    Args:
        request: The completion request; ``request.prompt`` is iterated, so it
            is expected to already be a list of prompts.
        n: Number of completions to generate per prompt.
        worker_addr: Base URL of the worker serving the model.

    Yields:
        ``data: ...`` server-sent-event strings.
    """
    model_name = request.model
    completion_id = f"cmpl-{shortuuid.random()}"
    finish_stream_events = []
    for prompt_index, text in enumerate(request.prompt):
        # The generation parameters depend only on the prompt, so build them
        # once per prompt instead of once per sample.
        gen_params = await get_gen_params(
            request.model,
            worker_addr,
            text,
            temperature=request.temperature,
            top_p=request.top_p,
            max_tokens=request.max_tokens,
            echo=request.echo,
            stop=request.stop,
        )
        for i in range(n):
            choice_index = prompt_index * n + i
            previous_text = ""
            async for content in generate_completion_stream(gen_params, worker_addr):
                if content["error_code"] != 0:
                    yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n"
                    yield "data: [DONE]\n\n"
                    return

                # The worker text is treated as cumulative: strip U+FFFD
                # placeholders and keep only what was not sent before.
                decoded_unicode = content["text"].replace("\ufffd", "")
                delta_text = decoded_unicode[len(previous_text) :]
                if len(decoded_unicode) > len(previous_text):
                    previous_text = decoded_unicode

                finish_reason = content.get("finish_reason", None)
                choice_data = CompletionResponseStreamChoice(
                    index=choice_index,
                    text=delta_text,
                    logprobs=content.get("logprobs", None),
                    finish_reason=finish_reason,
                )
                chunk = CompletionStreamResponse(
                    id=completion_id,
                    object="text_completion",
                    choices=[choice_data],
                    model=model_name,
                )
                if len(delta_text) == 0:
                    # Nothing new to send; keep the chunk only if it marks the
                    # end of this completion.
                    if finish_reason is not None:
                        finish_stream_events.append(chunk)
                    continue
                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"

    # Emit the held-back finish chunks (empty text plus finish_reason) last.
    for finish_chunk in finish_stream_events:
        yield f"data: {finish_chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
    yield "data: [DONE]\n\n"
async def generate_completion_stream(payload: Dict[str, Any], worker_addr: str):
    """
    POST ``payload`` to the worker's ``/worker_generate_stream`` endpoint and
    yield each decoded JSON message as it arrives.

    The response body is read as raw bytes and split on NUL (``b"\\0"``)
    delimiters; every non-empty frame is UTF-8 decoded and parsed as JSON.
    Bytes after the last delimiter stay buffered until more data arrives.

    Args:
        payload: JSON-serializable generation parameters sent to the worker.
        worker_addr: Base URL of the worker.

    Yields:
        One parsed JSON object per delimited frame.
    """
    delimiter = b"\0"
    async with httpx.AsyncClient() as client:
        async with client.stream(
            "POST",
            worker_addr + "/worker_generate_stream",
            headers=headers,
            json=payload,
            timeout=WORKER_API_TIMEOUT,
        ) as response:
            buffer = b""
            async for raw_chunk in response.aiter_raw():
                buffer += raw_chunk
                # Everything before the last delimiter is a complete frame;
                # the final piece (possibly empty) is the unfinished remainder.
                *frames, buffer = buffer.split(delimiter)
                for frame in frames:
                    if frame:
                        yield json.loads(frame.decode())
async def generate_completion(payload: Dict[str, Any], worker_addr: str):
    """
    Send ``payload`` to the worker's ``/worker_generate`` endpoint through
    ``fetch_remote`` and return whatever that helper produces.
    """
    endpoint = worker_addr + "/worker_generate"
    return await fetch_remote(endpoint, payload, "")
@app.post("/v1/embeddings", dependencies=[Depends(check_api_key)])
@app.post("/v1/engines/{model_name}/embeddings", dependencies=[Depends(check_api_key)])
async def create_embeddings(request: EmbeddingsRequest, model_name: str = None):
    """
    Compute embeddings for the request input.

    The model named in the URL path is used when the request body does not
    name one. Inputs are sent to ``get_embedding`` in slices of
    ``WORKER_API_EMBEDDING_BATCH_SIZE`` items; the first batch that reports a
    non-zero ``error_code`` aborts the request with an error response.
    """
    if request.model is None:
        request.model = model_name
    if (model_error := await check_model(request)) is not None:
        return model_error

    request.input = process_input(request.model, request.input)
    inputs = request.input
    batch_size = WORKER_API_EMBEDDING_BATCH_SIZE

    data = []
    token_num = 0
    for offset in range(0, len(inputs), batch_size):
        payload = {
            "model": request.model,
            "input": inputs[offset : offset + batch_size],
            "encoding_format": request.encoding_format,
        }
        embedding = await get_embedding(payload)
        if embedding.get("error_code", 0) != 0:
            return create_error_response(embedding["error_code"], embedding["text"])
        # Number results by their position in the overall input list.
        data.extend(
            {"object": "embedding", "embedding": emb, "index": position}
            for position, emb in enumerate(embedding["embedding"], start=offset)
        )
        token_num += embedding["token_num"]

    response = EmbeddingsResponse(
        data=data,
        model=request.model,
        usage=UsageInfo(
            prompt_tokens=token_num,
            total_tokens=token_num,
            completion_tokens=None,
        ),
    )
    return response.dict(exclude_none=True)
async def get_embedding(payload: Dict[str, Any]):
    """
    Request embeddings for ``payload`` from a worker serving ``payload["model"]``.

    The worker address is resolved with ``get_worker_address``; the reply of
    the worker's ``/worker_get_embeddings`` endpoint is parsed as JSON and
    returned.
    """
    worker_addr = await get_worker_address(payload["model"])
    embedding = await fetch_remote(worker_addr + "/worker_get_embeddings", payload)
    return json.loads(embedding)
### GENERAL API - NOT OPENAI COMPATIBLE ###
@app.post("/api/v1/token_check")
async def count_tokens(request: APITokenCheckRequest):
    """
    Report, for every prompt in the request, its token count, the model's
    context length, and whether prompt plus ``max_tokens`` fits that context.

    This endpoint is not part of the OpenAI API spec.
    """
    results = []
    for item in request.prompts:
        worker_addr = await get_worker_address(item.model)

        context_len = await fetch_remote(
            worker_addr + "/model_details",
            {"prompt": item.prompt, "model": item.model},
            "context_length",
        )
        token_num = await fetch_remote(
            worker_addr + "/count_token",
            {"prompt": item.prompt, "model": item.model},
            "count",
        )

        # The prompt fits when the requested completion still leaves it
        # within the context window.
        fits = token_num + item.max_tokens <= context_len
        results.append(
            APITokenCheckResponseItem(
                fits=fits, contextLength=context_len, tokenCount=token_num
            )
        )
    return APITokenCheckResponse(prompts=results)
@app.post("/api/v1/chat/completions")
async def create_chat_completion(request: APIChatCompletionRequest):
    """
    Create a chat completion for the given messages (general, non-OpenAI API).

    Compared with the request fields passed to ``get_gen_params``, this
    endpoint additionally forwards ``repetition_penalty`` when it is set.
    Streaming requests are answered with a ``text/event-stream`` response;
    otherwise ``request.n`` generations run concurrently and are merged into
    one ``ChatCompletionResponse`` with summed usage.
    """
    if (model_error := await check_model(request)) is not None:
        return model_error
    if (request_error := check_requests(request)) is not None:
        return request_error

    worker_addr = await get_worker_address(request.model)
    gen_params = await get_gen_params(
        request.model,
        worker_addr,
        request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens,
        echo=False,
        stop=request.stop,
    )
    if request.repetition_penalty is not None:
        gen_params["repetition_penalty"] = request.repetition_penalty

    # Replace the token budget with whatever check_length allows for this prompt.
    gen_params["max_new_tokens"] = await check_length(
        request,
        gen_params["prompt"],
        gen_params["max_new_tokens"],
        worker_addr,
    )

    if request.stream:
        return StreamingResponse(
            chat_completion_stream_generator(
                request.model, gen_params, request.n, worker_addr
            ),
            media_type="text/event-stream",
        )

    pending = [
        asyncio.create_task(generate_completion(gen_params, worker_addr))
        for _ in range(request.n)
    ]
    try:
        results = await asyncio.gather(*pending)
    except Exception as e:
        return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))

    choices = []
    usage = UsageInfo()
    for position, result in enumerate(results):
        if result["error_code"] != 0:
            return create_error_response(result["error_code"], result["text"])
        choices.append(
            ChatCompletionResponseChoice(
                index=position,
                message=ChatMessage(role="assistant", content=result["text"]),
                finish_reason=result.get("finish_reason", "stop"),
            )
        )
        # Accumulate every usage counter field by field.
        for field, amount in UsageInfo.parse_obj(result["usage"]).dict().items():
            setattr(usage, field, getattr(usage, field) + amount)

    return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
### END GENERAL API - NOT OPENAI COMPATIBLE ###
def create_openai_api_server():
    """
    Parse the command-line options and configure the API application.

    Installs the CORS middleware on ``app`` using the parsed options, stores
    the controller address and the optional API keys in ``app_settings``,
    logs the parsed arguments, and returns them.
    """
    parser = argparse.ArgumentParser(
        description="FastChat ChatGPT-Compatible RESTful API server."
    )
    parser.add_argument("--host", type=str, default="localhost", help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument(
        "--allow-credentials", action="store_true", help="allow credentials"
    )
    # The three CORS allow-lists share one shape: a JSON-encoded list that
    # defaults to the wildcard.
    for cors_field in ("origins", "methods", "headers"):
        parser.add_argument(
            f"--allowed-{cors_field}",
            type=json.loads,
            default=["*"],
            help=f"allowed {cors_field}",
        )
    parser.add_argument(
        "--api-keys",
        type=lambda s: s.split(","),
        help="Optional list of comma separated API keys",
    )
    parser.add_argument(
        "--ssl",
        action="store_true",
        help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
    )
    args = parser.parse_args()

    cors_options = {
        "allow_origins": args.allowed_origins,
        "allow_credentials": args.allow_credentials,
        "allow_methods": args.allowed_methods,
        "allow_headers": args.allowed_headers,
    }
    app.add_middleware(CORSMiddleware, **cors_options)

    app_settings.controller_address = args.controller_address
    app_settings.api_keys = args.api_keys

    logger.info(f"args: {args}")
    return args
if __name__ == "__main__":
    # Script entry point: configure the app from the CLI, then serve it.
    args = create_openai_api_server()
    server_options = {"host": args.host, "port": args.port, "log_level": "info"}
    if args.ssl:
        # With --ssl the key and certificate paths must be present in the
        # environment; a missing variable raises KeyError.
        server_options["ssl_keyfile"] = os.environ["SSL_KEYFILE"]
        server_options["ssl_certfile"] = os.environ["SSL_CERTFILE"]
    uvicorn.run(app, **server_options)
原理
这段代码实现了一个基于FastAPI的控制器(Controller),用于管理分布式的工作节点(Workers)并向客户端提供工作节点的地址。
技术原理如下:
1. 控制器(Controller)使用FastAPI框架构建Web应用程序,通过HTTP协议与客户端进行通信。
2. 控制器(Controller)启动一个后台线程,每隔 `CONTROLLER_HEART_BEAT_EXPIRATION` 秒检查一次工作节点的心跳信息,并移除心跳已过期的工作节点。
3. 控制器(Controller)维护一个工作节点信息字典(`worker_info`),记录每个工作节点的模型名称、速度、队列长度等信息。
4. 客户端可以通过向控制器(Controller)的API发送请求来注册工作节点、获取可用模型列表、获取工作节点的地址等。
5. 控制器(Controller)根据调度方法(`dispatch_method`)选择一个合适的工作节点,并将其地址返回给客户端。
6. 控制器(Controller)还提供了模拟工作节点的API接口,用于生成数据流并向客户端传输数据。
7. 控制器(Controller)通过定期接收工作节点的心跳信息来判断工作节点的活动状态,并根据心跳信息的过期时间移除失效的工作节点。
8. 控制器(Controller)还处理了没有可用工作节点和工作节点超时的情况,并返回相应的错误信息给客户端。
总体来说,这段代码实现了一个基于FastAPI的控制器(Controller),通过管理工作节点的注册、状态更新和调度,实现了分布式任务的管理和分发。