[Utility] A Simple Python Server Based on LLaMA

Overview

A minimal Python server built on LLaMA that exposes the model through an HTTP API.

Prerequisites

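  1. Download LLaMA and obtain the weights by following the instructions at: https://github.com/facebookresearch/llama
  2. Install the required Python libraries:
    • bottle: web server
    • fairscale: model-parallel library
    • pytorch: machine learning library
  3. If needed, choose which GPUs to use: export CUDA_VISIBLE_DEVICES=0,1,2...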

Implementation approach

Key points (see the code for details):

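  1. Configure the distributed environment:
    • MASTER_ADDR: address of the master node, usually 127.0.0.1.
    • MASTER_PORT: port of the master node; any free port in the range 1024-65535.
    • LOCAL_RANK/RANK: index of the current process (one process per GPU). Set dynamically in the program.
    • WORLD_SIZE: number of GPUs to use, which depends on the model: 7B->1, 13B->2, 70B->8.
  2. 13B and 70B need several GPUs, with one process per GPU. Each process holds one shard of the model, so during generation all of these processes must run in parallel. MultiProcessLlama in the code manages these processes.
  3. A global dictionary keyed by uuid stores the context of each conversation.
  4. Run: python <server-name>.py --llama-version 7B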
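The environment settings listed above can be collected in a small helper that also sanity-checks the GPU count before the server starts. This is a hypothetical sketch: `WORLD_SIZE_BY_VERSION`, `dist_env` and `check_visible_gpus` are illustrative names, not part of the server code. It assumes a single machine, where a worker's global rank equals its local GPU index, and it takes the version-to-GPU mapping from the list above.

```python
import os
from typing import Dict, Optional

# Model-parallel size per LLaMA 2 chat checkpoint, as in the list above:
# one process (and one GPU) per checkpoint shard.
WORLD_SIZE_BY_VERSION = {"7B": 1, "13B": 2, "70B": 8}


def dist_env(rank: int, world_size: int,
             addr: str = "127.0.0.1", port: int = 65288) -> Dict[str, str]:
    """Env vars for one worker: the env:// variables plus LOCAL_RANK."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} out of range for world size {world_size}")
    if not 1024 <= port <= 65535:
        raise ValueError(f"port {port} outside 1024-65535")
    return {
        "MASTER_ADDR": addr,
        "MASTER_PORT": str(port),
        "RANK": str(rank),
        "LOCAL_RANK": str(rank),  # single machine: global rank == GPU index
        "WORLD_SIZE": str(world_size),
    }


def check_visible_gpus(version: str, visible: Optional[str] = None) -> int:
    """Return the world size for `version`; raise if too few GPUs are visible.

    `visible` defaults to $CUDA_VISIBLE_DEVICES. If that variable is unset,
    the GPUs cannot be counted here and the required size is returned as-is.
    """
    world_size = WORLD_SIZE_BY_VERSION[version]
    if visible is None:
        visible = os.environ.get("CUDA_VISIBLE_DEVICES")
    if visible is not None:
        devices = [d for d in visible.split(",") if d.strip()]
        if len(devices) < world_size:
            raise RuntimeError(
                f"LLaMA {version} needs {world_size} GPUs, "
                f"but CUDA_VISIBLE_DEVICES exposes {len(devices)}")
    return world_size
```

A worker process could then call `os.environ.update(dist_env(rank, world_size))` before `torch.distributed.init_process_group("nccl")`. The default `env://` initialization reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the environment.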

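Implemented APIs (request parameters are sent as form fields):

  • POST /api/initchat: initialize a conversation
    • Request parameters
      • content: system prompt (defaults to "Feel free to answer the question.")
    • Response (JSON)
      • uuid: session identifier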
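  • POST /api/chat: send a message in the conversation
    • Request parameters
      • uuid: session identifier returned by /api/initchat
      • content: the user's message
    • Response (JSON)
      • uuid: session identifier returned by /api/initchat
      • status: 0 = failure (unknown uuid); 1 = success
      • response: LLaMA's reply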
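  • POST /api/reset: reset the conversation, keeping only the system prompt
    • Request parameters
      • uuid: session identifier returned by /api/initchat
    • Response (JSON)
      • uuid: session identifier returned by /api/initchat
      • status: 0 = failure (unknown uuid); 1 = success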
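  • POST /api/chat_once: single-turn chat. The reply is generated from the system prompt and this one message only; the stored conversation history is neither used nor modified.
    • Request parameters
      • uuid: session identifier returned by /api/initchat
      • content: the user's message
    • Response (JSON)
      • uuid: session identifier returned by /api/initchat
      • status: 0 = failure (unknown uuid); 1 = success
      • response: LLaMA's reply
  • POST /api/close: end the conversation and delete its stored history
    • Request parameters
      • uuid: session identifier returned by /api/initchat
    • Response (JSON)
      • uuid: session identifier returned by /api/initchat
      • status: 0 = failure (unknown uuid); 1 = success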
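The bookkeeping behind these endpoints (the global dictionary of conversations) can be modelled without a GPU. The following is a simplified, hypothetical in-memory sketch: `SessionStore` and `fake_generate` are illustrative names, and the empty-content shortcut of the real handlers is omitted. The stub stands in for `Llama.chat_completion` and returns the same result shape the server reads (`result["generation"]["content"]`).

```python
import uuid
from typing import Callable, Dict, List

Message = Dict[str, str]                   # {"role": ..., "content": ...}
Generate = Callable[[List[Message]], dict]


def fake_generate(dialog: List[Message]) -> dict:
    """Stub with the same shape as one entry of Llama.chat_completion()."""
    return {"generation": {"role": "assistant",
                           "content": f"echo: {dialog[-1]['content']}"}}


class SessionStore:
    """Per-uuid dialog history, mirroring the server's global dict."""

    def __init__(self, generate: Generate = fake_generate):
        self.generate = generate
        self.sessions: Dict[str, List[Message]] = {}

    def init_chat(self, system_prompt: str) -> str:
        sid = str(uuid.uuid4())
        self.sessions[sid] = [{"role": "system", "content": system_prompt}]
        return sid

    def chat(self, sid: str, content: str) -> dict:
        if sid not in self.sessions:
            return {"uuid": sid, "status": 0}
        hist = self.sessions[sid]
        hist.append({"role": "user", "content": content})
        answer = self.generate(hist)["generation"]["content"]
        hist.append({"role": "assistant", "content": answer})
        return {"uuid": sid, "status": 1, "response": answer}

    def chat_once(self, sid: str, content: str) -> dict:
        if sid not in self.sessions:
            return {"uuid": sid, "status": 0}
        # Only the system prompt is used; stored history is left untouched.
        dialog = [self.sessions[sid][0], {"role": "user", "content": content}]
        answer = self.generate(dialog)["generation"]["content"]
        return {"uuid": sid, "status": 1, "response": answer}

    def reset(self, sid: str) -> dict:
        if sid not in self.sessions:
            return {"uuid": sid, "status": 0}
        del self.sessions[sid][1:]         # keep only the system prompt
        return {"uuid": sid, "status": 1}

    def close(self, sid: str) -> dict:
        if self.sessions.pop(sid, None) is None:
            return {"uuid": sid, "status": 0}
        return {"uuid": sid, "status": 1}
```

Keeping the history as a list of role/content dicts means it can be passed directly as one dialog to `chat_completion`, which expects an optional system message followed by alternating user and assistant turns.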

Code

import argparse
import multiprocessing
import os
import uuid as libuuid
from multiprocessing import Queue

import bottle
import torch
import torch.distributed
from bottle import get, post, request
from fairscale.nn.model_parallel.initialize import initialize_model_parallel

from llama import Llama

_arg = argparse.ArgumentParser("Server LLAMA Chat Web")
_arg.add_argument("--llama-version", type=str,
                  default="7B", help="LLAMA version, available: [7B, 13B, 70B]")
args = _arg.parse_args()

if args.llama_version not in ["7B", "13B", "70B"]:
    raise ValueError("LLaMA version not found. Supported: 7B, 13B, 70B")


class MultiProcessLlama:
    def __init__(self, world_size):
        self.in_queue_list = [
            Queue(8)
            for _ in range(world_size)
        ]
        self.out_queue = Queue(8)

        self.world_size = world_size
        self.process_list = []
        print("init Done")

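    def chat_completion(self, *args, **kwargs):
        # Every model-parallel rank must take part in each forward pass,
        # so the same call is sent to all worker processes.
        for ique in self.in_queue_list:
            ique.put((
                "chat_completion",
                args,
                kwargs,
            ))

        # Only rank 0 writes its (identical) result back.
        out = self.out_queue.get()
        return out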

    def start(self, *args, **kwargs):
        """Start one worker process per GPU; args/kwargs go to Llama.build()."""

        # A nested function as Process target only works with the "fork" start
        # method; it cannot be pickled for "spawn" or "forkserver".
        def __loop(rank: int, world_size: int, in_queue: Queue, out_queue, build_args, build_kwargs):
            if rank == 0:
                assert out_queue is not None

            os.environ["MASTER_ADDR"] = "127.0.0.1"
            os.environ["MASTER_PORT"] = "65288"
            # Llama.build() selects the GPU via LOCAL_RANK.
            os.environ["LOCAL_RANK"] = str(rank)
            # RANK and WORLD_SIZE are passed explicitly instead of via env vars.
            torch.distributed.init_process_group(
                "nccl",
                rank=rank,
                world_size=world_size,
            )
            initialize_model_parallel(world_size)

            # Each process loads its own shard of the checkpoint.
            generator = Llama.build(
                *build_args,
                **build_kwargs,
            )

            while True:
                cmd, call_args, call_kwargs = in_queue.get()
                out = None
                if cmd is None:  # shutdown sentinel sent by join()
                    break
                if cmd == "chat_completion":
                    out = generator.chat_completion(*call_args, **call_kwargs)
                elif cmd == "text_completion":
                    out = generator.text_completion(*call_args, **call_kwargs)
                else:
                    print("Warning, unknown command", cmd)
                # All ranks produce the same result; only rank 0 writes it back.
                if rank == 0:
                    out_queue.put(out)

        for i in range(self.world_size):
            pi = multiprocessing.Process(
                target=__loop,
                args=(
                    i,
                    self.world_size,
                    self.in_queue_list[i],
                    self.out_queue if i == 0 else None,
                    args,
                    kwargs,
                )
            )
            self.process_list.append(pi)
            pi.start()

    def join(self):
        # One shutdown sentinel per worker: each worker has its own input queue.
        for que in self.in_queue_list:
            que.put((None, None, None))

        for pi in self.process_list:
            pi.join()


chat_uuid_dict = dict()


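def generate_chat(generator, chat_info):
    if generator is None:
        # Placeholder with the same shape as a chat_completion() result,
        # so callers can still read result["generation"]["content"].
        return {"generation": {"role": "assistant", "content": "(ᗜ_ᗜ)"}}

    return generator.chat_completion(
        [chat_info],  # type: ignore  # a batch containing a single dialog
        max_gen_len=None,
        temperature=0.6,
        top_p=0.99,
    )[0]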


Global_generator = None

app = bottle.Bottle()


@app.route("/")
def index():
    return bottle.template("./web/statics/index.html")


@app.route("/statics/<filename:path>")
def serve_static(filename):
    return bottle.static_file(filename, "web/statics")


@app.post("/api/close")
def api_close():
    uuid = request.forms.get("uuid", "")
    if not uuid in chat_uuid_dict:
        return {"uuid": uuid, "status": 0}
    del chat_uuid_dict[uuid]
    return {"uuid": uuid, "status": 1}


@app.post("/api/chat")
def api_chat():
    uuid = request.forms.get("uuid", "")
    content = request.forms.get("content", "")
    if content == "":
        return {
            "uuid": uuid,
            "status": 1,
            "response": "(ᗜ_ᗜ)",
        }

    if not uuid in chat_uuid_dict:
        return {"uuid": uuid, "status": 0}

    chat_hist = chat_uuid_dict[uuid]

    chat_hist.append({
        "role": "user",
        "content": content,
    })

    result = generate_chat(Global_generator, chat_hist)
    answer = result["generation"]['content']
    chat_hist.append({
        "role": "assistant",
        "content": answer
    })

    return {
        "uuid": uuid,
        "status": 1,
        "response": answer,
    }


@app.post("/api/initchat")
def api_initchat():
    content = request.forms.get("content", "Feel free to answer the question.")

    while True:
        uuid = str(libuuid.uuid4())
        if not uuid in chat_uuid_dict:
            chat_uuid_dict[uuid] = [{
                "role": "system",
                "content": content,
            }]
            break
    return {
        "uuid": uuid,
    }


@app.post("/api/chat_once")
def api_chat_once():
    uuid = request.forms.get("uuid", "")
    content = request.forms.get("content", "")
    if content == "":
        return {
            "uuid": uuid,
            "status": 1,
            "response": "(ᗜ_ᗜ)",
        }

    if not uuid in chat_uuid_dict:
        return {"uuid": uuid, "status": 0}

    chat_hist = []
    chat_hist.append(chat_uuid_dict[uuid][0])

    chat_hist.append({
        "role": "user",
        "content": content,
    })

    result = generate_chat(Global_generator, chat_hist)
    answer = result["generation"]['content']

    return {
        "uuid": uuid,
        "status": 1,
        "response": answer,
    }


@app.post("/api/reset")
def api_reset():
    uuid = request.forms.get("uuid", "")
    if not uuid in chat_uuid_dict:
        return {"uuid": uuid, "status": 0}
    chat_hist = chat_uuid_dict[uuid]  # type: list
    init_msg = chat_hist[0]
    chat_hist.clear()
    chat_hist.append(init_msg)
    return {"uuid": uuid, "status": 1}


if __name__ == "__main__":
    print("System loading...")
    if Global_generator is None:
        if args.llama_version == "7B":
            os.environ["MASTER_ADDR"] = "127.0.0.1"
            os.environ["MASTER_PORT"] = "65288"
            os.environ["RANK"] = "0"
            os.environ["WORLD_SIZE"] = "1"
            Global_generator = Llama.build(
                ckpt_dir="./downloads/llama-2-7b-chat",
                tokenizer_path="./downloads/tokenizer.model",
                max_seq_len=4096,  # Llama 2 was trained with a 4096-token context
                max_batch_size=1,
                model_parallel_size=1,
            )
        elif args.llama_version == "13B":
            Global_generator = MultiProcessLlama(2)
            Global_generator.start(
                ckpt_dir="./downloads/llama-2-13b-chat",
                tokenizer_path="./downloads/tokenizer.model",
                max_seq_len=2048,
                max_batch_size=1,
                model_parallel_size=2,
            )

        elif args.llama_version == "70B":
            print("Starting 8 model-parallel worker processes")
            Global_generator = MultiProcessLlama(8)
            Global_generator.start(
                ckpt_dir="./downloads/llama-2-70b-chat",
                tokenizer_path="./downloads/tokenizer.model",
                max_seq_len=2048,
                max_batch_size=1,
                model_parallel_size=8,
            )

    print("Init with", args.llama_version)
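    # Bottle's default server (wsgiref) handles one request at a time, which
    # also keeps requests to the worker queues from interleaving.
    app.run(host='0.0.0.0', port=8088, debug=False, reloader=False)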
    if args.llama_version != "7B":
        try:
            Global_generator.join()
        except Exception as e:
            print(e)
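The core of MultiProcessLlama is a fan-out/gather protocol. Every request is copied to every worker's input queue, because all model-parallel shards must run each forward pass together, and only rank 0 sends the result back. The sketch below reproduces that protocol with threads and a stub model in place of `Llama.build`, so it runs without GPUs or torch.distributed. `FanOutPool` and `stub_worker_model` are hypothetical names, and the threads only stand in for the real server's processes.

```python
import queue
import threading
from typing import Any, Callable, List, Optional


def stub_worker_model(rank: int) -> Callable[..., Any]:
    """Stand-in for one model shard; every rank returns the same answer."""
    def chat_completion(dialogs, **kwargs):
        return [{"generation": {"role": "assistant",
                                "content": f"{len(dialogs[0])} messages"}}]
    return chat_completion


def _loop(rank: int, in_q: queue.Queue, out_q: Optional[queue.Queue],
          build: Callable[[int], Callable[..., Any]]) -> None:
    model = build(rank)
    while True:
        cmd, call_args, call_kwargs = in_q.get()
        if cmd is None:                      # shutdown sentinel
            break
        out = model(*call_args, **call_kwargs) if cmd == "chat_completion" else None
        if rank == 0:                        # all ranks agree; only rank 0 replies
            out_q.put(out)


class FanOutPool:
    def __init__(self, world_size: int, build=stub_worker_model):
        self.in_queues: List[queue.Queue] = [queue.Queue(8) for _ in range(world_size)]
        self.out_queue: queue.Queue = queue.Queue(8)
        self.workers = [
            threading.Thread(target=_loop, daemon=True,
                             args=(r, self.in_queues[r],
                                   self.out_queue if r == 0 else None, build))
            for r in range(world_size)
        ]
        for w in self.workers:
            w.start()

    def chat_completion(self, *args, **kwargs):
        for q in self.in_queues:             # fan out the same call to every rank
            q.put(("chat_completion", args, kwargs))
        return self.out_queue.get()          # gather the single reply from rank 0

    def join(self):
        for q in self.in_queues:             # one sentinel per worker queue
            q.put((None, None, None))
        for w in self.workers:
            w.join()
```

With real shards, the ranks stay in lockstep because the model-parallel layers exchange tensors on every step. That is why a request sent to only some workers would hang rather than fail.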

Remote calls

A client-side helper class:

import requests
import http

class ChatLlama:
    def __init__(self, addr, content: str = ""):
        self.addr = addr
        self.chat_uuid = self.init_chat(content)

    def init_chat(self, content: str):
        resp = requests.post(f"{self.addr}/api/initchat", data={ "content": content })

        if resp.status_code != http.HTTPStatus.OK:
            raise ValueError(resp.status_code)

        chat_uuid = resp.json()["uuid"]
        resp.close()

        print("init UUID", chat_uuid)
        return chat_uuid

    def chat_request(self, context) -> str:
        resp = requests.post(f"{self.addr}/api/chat", data={
            "uuid": self.chat_uuid,
            "content": context,
        })

        if resp.status_code != http.HTTPStatus.OK:
            raise ValueError("HTTP error", resp.status_code)

        ans = resp.json()
        if ans["status"] == 0:
            raise ValueError("UUID does not exist")

        return ans["response"]

    def chat_once(self, context) -> str:
        resp = requests.post(f"{self.addr}/api/chat_once", data={
            "uuid": self.chat_uuid,
            "content": context,
        })

        if resp.status_code != http.HTTPStatus.OK:
            raise ValueError("HTTP error", resp.status_code)

        ans = resp.json()
        if ans["status"] == 0:
            raise ValueError("UUID does not exist")

        return ans["response"]
    
    def chat_reset(self) -> bool:
        resp = requests.post(f"{self.addr}/api/reset", data={
            "uuid": self.chat_uuid,
        })

        if resp.status_code != http.HTTPStatus.OK:
            raise ValueError("HTTP error", resp.status_code)

        ans = resp.json()
        return ans["status"] == 1

    def close_chat(self):
        resp = requests.post(f"{self.addr}/api/close", data={
            "uuid": self.chat_uuid,
        })

        resp.close()
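To exercise a client without loading a model, one option is a throwaway stub that speaks the same wire format: form-encoded POST bodies in, JSON out. The sketch below is a hypothetical test double that uses only the standard library. `StubHandler`, `post_form` and `SESSIONS` are illustrative names. The stub implements only /api/initchat and /api/chat, and it returns an echo reply instead of calling the model.

```python
import json
import threading
import urllib.parse
import urllib.request
import uuid
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer

SESSIONS = {}  # uuid -> list of {"role", "content"} messages


class StubHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        length = int(self.headers.get("Content-Length", 0))
        form = urllib.parse.parse_qs(self.rfile.read(length).decode())

        def field(name: str, default: str = "") -> str:
            return form.get(name, [default])[0]

        if self.path == "/api/initchat":
            sid = str(uuid.uuid4())
            SESSIONS[sid] = [{"role": "system", "content": field("content")}]
            body = {"uuid": sid}
        elif self.path == "/api/chat":
            sid = field("uuid")
            if sid not in SESSIONS:
                body = {"uuid": sid, "status": 0}
            else:
                answer = "echo: " + field("content")  # stub, no model call
                SESSIONS[sid] += [{"role": "user", "content": field("content")},
                                  {"role": "assistant", "content": answer}]
                body = {"uuid": sid, "status": 1, "response": answer}
        else:
            self.send_error(404)
            return
        data = json.dumps(body).encode()
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(data)))
        self.end_headers()
        self.wfile.write(data)

    def log_message(self, *args):  # keep output quiet
        pass


def post_form(url: str, fields: dict) -> dict:
    """POST application/x-www-form-urlencoded data and decode the JSON reply."""
    data = urllib.parse.urlencode(fields).encode()
    with urllib.request.urlopen(url, data=data) as resp:
        return json.loads(resp.read())


server = ThreadingHTTPServer(("127.0.0.1", 0), StubHandler)  # port 0: any free port
threading.Thread(target=server.serve_forever, daemon=True).start()
addr = f"http://127.0.0.1:{server.server_address[1]}"
```

Because ChatLlama sends its parameters with `requests.post(..., data=...)`, which form-encodes a dict, `ChatLlama(addr, "Be brief.")` can be pointed at this stub to check its initialization and chat handling. Stop the stub with `server.shutdown()`.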