Overview
Build a lightweight Python server on top of LLAMA that exposes a chat API.
Prerequisites
- Download LLAMA and follow its instructions to obtain the weights: https://github.com/facebookresearch/llama
- Install the required Python libraries:
  - bottle: serves the web API
  - fairscale: model-parallelism library
  - pytorch: machine-learning framework
- If necessary, choose which GPUs to use:
export CUDA_VISIBLE_DEVICES=0,1,2...
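To verify that the selected GPUs are visible to PyTorch, a quick check (a generic sketch, not specific to this project):

import torch
# Prints the number of GPUs exposed via CUDA_VISIBLE_DEVICES
print(torch.cuda.device_count())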
Design
Key points (see the code for details):
- Configure the environment:
  - MASTER_ADDR: master node address, typically 127.0.0.1
  - MASTER_PORT: master node port, any free port in 1024-65535
  - LOCAL_RANK/RANK: index of the current node, set dynamically by the program
  - WORLD_SIZE: number of GPUs, determined by the model: 7B -> 1, 13B -> 2, 70B -> 8
- The 13B and 70B models need multiple GPUs, one process per GPU, and during generation these processes must run in parallel. The MultiProcessLlama class in the code manages them.
- A global dictionary records the conversation context of each session (its shape is sketched below).
- Run: python <server-name>.py --llama-version 7B
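For reference, each entry in the global context dictionary is a list of role/content messages, as used by the server code below (the UUID key here is illustrative):

chat_uuid_dict = {
    "3f2b...": [  # key issued by /api/initchat
        {"role": "system", "content": "system prompt"},
        {"role": "user", "content": "user message"},
        {"role": "assistant", "content": "model reply"},
    ],
}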
Implemented APIs:
- POST /api/initchat: initialize a chat session
  - Request parameters
    - content: system prompt
  - Response (JSON)
    - uuid: user identifier
- POST /api/chat: send a chat message
  - Request parameters
    - uuid: user identifier, from /api/initchat
    - content: the user's message
  - Response (JSON)
    - uuid: user identifier, from /api/initchat
    - status: 0 = failure, 1 = success
    - response: LLAMA's answer
- POST /api/reset: reset the chat state, keeping only the system prompt
  - Request parameters
    - uuid: user identifier, from /api/initchat
  - Response (JSON)
    - uuid: user identifier, from /api/initchat
    - status: 0 = failure, 1 = success
- POST /api/chat_once: single-turn chat; equivalent to /api/chat followed by /api/reset
  - Request parameters
    - uuid: user identifier, from /api/initchat
    - content: the user's message
  - Response (JSON)
    - uuid: user identifier, from /api/initchat
    - status: 0 = failure, 1 = success
    - response: LLAMA's answer
- POST /api/close (defined in the code below): close a session and drop its context
  - Request parameters
    - uuid: user identifier, from /api/initchat
  - Response (JSON)
    - uuid: user identifier
    - status: 0 = failure, 1 = success
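A minimal end-to-end call sequence (a sketch using the requests library, assuming the server below is running locally on its default port 8088):

import requests

BASE = "http://127.0.0.1:8088"  # matches app.run() in the server code

# 1. Initialize a session with a system prompt
uuid = requests.post(f"{BASE}/api/initchat",
                     data={"content": "You are a helpful assistant."}).json()["uuid"]

# 2. Chat within the session context
ans = requests.post(f"{BASE}/api/chat",
                    data={"uuid": uuid, "content": "Hello!"}).json()
print(ans["status"], ans["response"])

# 3. Reset the session, keeping only the system prompt
requests.post(f"{BASE}/api/reset", data={"uuid": uuid})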
Server code
import argparse
import multiprocessing
import os
import uuid as libuuid
from multiprocessing import Queue
import bottle
import torch
import torch.distributed
from bottle import get, post, request
from fairscale.nn.model_parallel.initialize import initialize_model_parallel
from llama import Llama
_arg = argparse.ArgumentParser("Server LLAMA Chat Web")
_arg.add_argument("--llama-version", type=str,
                  default="7B", help="LLAMA version, available: [7B, 13B, 70B]")
args = _arg.parse_args()
if args.llama_version not in ["7B", "13B", "70B"]:
    raise ValueError("LLaMA version not found. Supported versions: 7B, 13B, 70B")
class MultiProcessLlama:
    """Drives one worker process per GPU and broadcasts generation commands to all of them."""
    def __init__(self, world_size):
        # One input queue per worker; results are read back from rank 0 only
        self.in_queue_list = [
            Queue(8)
            for _ in range(world_size)
        ]
        self.out_queue = Queue(8)
        self.world_size = world_size
        self.process_list = []
        print("init Done")
def chat_completion(self, *args, **kwargs):
for ique in self.in_queue_list:
# print("Call chat_completion")
ique.put((
"chat_completion",
args,
kwargs,
))
out = self.out_queue.get()
return out
def start(self, *args, **kwargs):
def __loop(rank: int, world_size: int, in_queue: Queue, out_queue: Queue, args, kwargs):
if rank == 0:
                assert out_queue is not None
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "65288"
# os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
# os.environ["WORLD_SIZE"] = str(world_size)
torch.distributed.init_process_group(
"nccl",
rank=rank,
world_size=world_size,
)
initialize_model_parallel(world_size)
generator = Llama.build(
*args,
**kwargs,
)
while True:
# print(f"[{rank}] in queue wait")
cmd, args, kwargs = in_queue.get()
# print(f"[{rank}] in queue get", cmd)
out = None
if cmd is None:
break
if cmd == "chat_completion":
out = generator.chat_completion(*args, **kwargs)
elif cmd == "text_completion":
out = generator.text_completion(*args, **kwargs)
else:
print("Warnning, unknown command", cmd)
# all responses are the same. write to rank 0 only
if rank == 0:
out_queue.put(out)
# print(f"[{rank}] {cmd} {args}, {kwargs} => {out}")
for i in range(self.world_size):
pi = multiprocessing.Process(
target=__loop,
args=(
i,
self.world_size,
self.in_queue_list[i],
self.out_queue if i == 0 else None,
args,
kwargs,
)
)
self.process_list.append(pi)
pi.start()
def join(self):
        for _ in range(4):
            # enqueue the stop signal multiple times so no worker misses it
            for que in self.in_queue_list:
                que.put((None, None, None))
for pi in self.process_list:
pi.join()
chat_uuid_dict = dict()
def generate_chat(generator, chat_info):
if generator is None:
return [{"role": "assistant", "content": "(ᗜ_ᗜ)"}]
return generator.chat_completion(
[chat_info], # type: ignore
max_gen_len=None,
temperature=0.6,
top_p=0.99,
)[0]
Global_generator = None
app = bottle.Bottle()
@app.route("/")
def index():
return bottle.template("./web/statics/index.html")
@app.route("/statics/<filename:path>")
def serve_static(filename):
return bottle.static_file(filename, "web/statics")
@app.post("/api/close")
def api_close():
uuid = request.forms.get("uuid", "")
    if uuid not in chat_uuid_dict:
return {"uuid": uuid, "status": 0}
del chat_uuid_dict[uuid]
return {"uuid": uuid, "status": 1}
@app.post("/api/chat")
def api_chat():
uuid = request.forms.get("uuid", "")
content = request.forms.get("content", "")
if content == "":
return {
"uuid": uuid,
"status": 1,
"response": "(ᗜ_ᗜ)",
}
    if uuid not in chat_uuid_dict:
return {"uuid": uuid, "status": 0}
chat_hist = chat_uuid_dict[uuid]
chat_hist.append({
"role": "user",
"content": content,
})
result = generate_chat(Global_generator, chat_hist)
answer = result["generation"]['content']
chat_hist.append({
"role": "assistant",
"content": answer
})
return {
"uuid": uuid,
"status": 1,
"response": answer,
}
@app.post("/api/initchat")
def api_initchat():
content = request.forms.get("content", "Feel free to answer the question.")
while True:
uuid = str(libuuid.uuid4())
        if uuid not in chat_uuid_dict:
chat_uuid_dict[uuid] = [{
"role": "system",
"content": content,
}]
break
return {
"uuid": uuid,
}
@app.post("/api/chat_once")
def api_initchat():
uuid = request.forms.get("uuid", "")
content = request.forms.get("content", "")
if content == "":
return {
"uuid": uuid,
"status": 1,
"response": "(ᗜ_ᗜ)",
}
    if uuid not in chat_uuid_dict:
return {"uuid": uuid, "status": 0}
chat_hist = []
chat_hist.append(chat_uuid_dict[uuid][0])
chat_hist.append({
"role": "user",
"content": content,
})
result = generate_chat(Global_generator, chat_hist)
answer = result["generation"]['content']
return {
"uuid": uuid,
"status": 1,
"response": answer,
}
@app.post("/api/reset")
def api_reset():
uuid = request.forms.get("uuid", "")
    if uuid not in chat_uuid_dict:
return {"uuid": uuid, "status": 0}
chat_hist = chat_uuid_dict[uuid] # type: list
init_msg = chat_hist[0]
chat_hist.clear()
chat_hist.append(init_msg)
return {"uuid": uuid, "status": 1}
if __name__ == "__main__":
print("System loadding...")
local_rank = int(os.environ.get("LOCAL_RANK", 0))
if Global_generator is None:
if args.llama_version == "7B":
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "65288"
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
Global_generator = Llama.build(
ckpt_dir="./downloads/llama-2-7b-chat",
tokenizer_path="./downloads/tokenizer.model",
max_seq_len=9000,
max_batch_size=1,
model_parallel_size=1,
)
elif args.llama_version == "13B":
Global_generator = MultiProcessLlama(2)
Global_generator.start(
ckpt_dir="./downloads/llama-2-13b-chat",
tokenizer_path="./downloads/tokenizer.model",
max_seq_len=2048,
max_batch_size=1,
model_parallel_size=2,
)
elif args.llama_version == "70B":
print("Use torch run")
Global_generator = MultiProcessLlama(8)
Global_generator.start(
ckpt_dir="./downloads/llama-2-70b-chat",
tokenizer_path="./downloads/tokenizer.model",
max_seq_len=2048,
max_batch_size=1,
model_parallel_size=8,
)
print("Init with", args.llama_version)
app.run(host='0.0.0.0', port=8088, debug=False, reloader=False)
if args.llama_version != "7B":
try:
Global_generator.join()
except Exception as e:
print(e)
Remote invocation
A client-side utility class.
import requests
import http
class ChatLlama:
def __init__(self, addr, content: str = ""):
self.addr = addr
self.chat_uuid = self.init_chat(content)
def init_chat(self, content: str):
resp = requests.post(f"{self.addr}/api/initchat", data={ "content": content })
if resp.status_code != http.HTTPStatus.OK:
raise ValueError(resp.status_code)
chat_uuid = resp.json()["uuid"]
resp.close()
print("init UUID", chat_uuid)
return chat_uuid
def chat_request(self, context) -> str:
resp = requests.post(f"{self.addr}/api/chat", data={
"uuid": self.chat_uuid,
"content": context,
})
if resp.status_code != http.HTTPStatus.OK:
raise ValueError("HTTP error", resp.status_code)
ans = resp.json()
if ans["status"] == 0:
raise ValueError("UUID does not exist")
return ans["response"]
def chat_once(self, context) -> str:
resp = requests.post(f"{self.addr}/api/chat_once", data={
"uuid": self.chat_uuid,
"content": context,
})
if resp.status_code != http.HTTPStatus.OK:
raise ValueError("HTTP error", resp.status_code)
ans = resp.json()
if ans["status"] == 0:
raise ValueError("UUID does not exist")
return ans["response"]
def chat_reset(self) -> bool:
resp = requests.post(f"{self.addr}/api/reset", data={
"uuid": self.chat_uuid,
})
if resp.status_code != http.HTTPStatus.OK:
raise ValueError("HTTP error", resp.status_code)
ans = resp.json()
return ans["status"] == 1
def close_chat(self):
resp = requests.post(f"{self.addr}/api/close", data={
"uuid": self.chat_uuid,
})
resp.close()
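A quick usage sketch (the server address and the prompts here are placeholders):

if __name__ == "__main__":
    # Point this at the machine running the bottle server
    client = ChatLlama("http://127.0.0.1:8088", content="You are a helpful assistant.")
    print(client.chat_request("Hello, who are you?"))
    client.chat_reset()  # keep only the system prompt
    print(client.chat_once("Summarize LLAMA in one sentence."))
    client.close_chat()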