Project Training, Large Model Series 1: MING Model Deployment and OpenAI API Service Implementation (1): ModelWorker

In the previous article we successfully deployed MING on an AutoDL server. One problem remains: the MING developers never tried exposing the model through a RESTful API, so the API service components they ship are incomplete.

In general, a FastChat deployment works like this:

1. Start the controller service at address A.

2. Start the model_worker service at address B and register it with the controller.

3. Start the api_server service at address C.

In this flow, the user first calls an api_server endpoint and sends a JSON payload. The api_server parses the JSON, asks the controller for the address of the model_worker that hosts the target model, then sends a request to that model_worker, receives the generated message stream, and returns it to the user.
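
A minimal sketch of what the api_server does for a single request (illustrative only: /get_worker_address follows the lookup route exposed by FastChat's controller, /worker_generate_stream is the route our worker defines below, and the "model" field name follows FastChat's convention):

import requests

CONTROLLER_ADDR = "http://localhost:21001"

def route_request(model_name, params):
    # 1. ask the controller which worker currently serves this model
    r = requests.post(CONTROLLER_ADDR + "/get_worker_address",
                      json={"model": model_name})
    worker_addr = r.json()["address"]
    # 2. forward the generation request to that worker as a streaming POST;
    #    the worker answers with b"\0"-separated JSON chunks (parsing is shown later in this article)
    return requests.post(worker_addr + "/worker_generate_stream",
                         json=params, stream=True)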

The MING developers wrote a simple chat service, but provided no deployment-related services for MING-MOE. Based on the existing code, we therefore wrote a simple model_worker, controller, and API service for the ming-moe model.

1. Introduction to the MoE model architecture

The MoE (Mixture of Experts) architecture is a model architecture based on conditional computation: a large model is split into several smaller expert models, each responsible for a particular task or subset of the data. At run time, the experts that process an input are selected according to the characteristics of that input. This lets the model keep its quality while significantly improving efficiency and scalability.

MING-MOE is an MoE model built on Qwen, which means the corresponding Qwen base model is required and the adapter weights must be merged into it before the model can be used normally.
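
For intuition, the conditional computation at the heart of MoE can be sketched as a toy top-k router over a handful of small experts (purely illustrative, not MING-MOE's actual implementation):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyMoELayer(nn.Module):
    """Toy MoE layer: a learned router sends each token to its top-k experts."""
    def __init__(self, d_model=64, n_experts=8, k=2):
        super().__init__()
        self.experts = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(n_experts))
        self.router = nn.Linear(d_model, n_experts)
        self.k = k

    def forward(self, x):                           # x: (n_tokens, d_model)
        scores = self.router(x)                     # (n_tokens, n_experts)
        weights, idx = scores.topk(self.k, dim=-1)  # keep only the k best experts per token
        weights = F.softmax(weights, dim=-1)
        out = torch.zeros_like(x)
        for slot in range(self.k):                  # only the selected experts do any work
            for e, expert in enumerate(self.experts):
                mask = idx[:, slot] == e
                if mask.any():
                    out[mask] += weights[mask, slot, None] * expert(x[mask])
        return out

y = ToyMoELayer()(torch.randn(10, 64))              # each of the 10 tokens is routed to 2 of the 8 experts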

2. Model_worker

Weight merging

The load_molora_pretrained_model function in ming/model/builder.py uses the PEFT library to merge the LoRA weights with the base model; the worker performs this merge once at initialization.

The ming-moe developers have already implemented this merge with the base model, so we only need to call it:

 self.tokenizer, self.model, self.context_len, _ = load_molora_pretrained_model(model_path, model_base,
                                                                                       model_name, load_8bit,
                                                                                       None,
                                                                                       use_logit_bias=None,
                                                                                       only_load=None,
                                                                                       expert_selection=None)
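
For reference, the usual way PEFT folds LoRA adapters into a base model looks roughly like the following. This is only a generic sketch: MING's MoLoRA adapters go through load_molora_pretrained_model rather than a plain PeftModel, and the paths below are simply the ones used later in this article.

from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

base_path = "/root/autodl-tmp/Qwen1.5-4B-Chat"      # Qwen base model
lora_path = "/root/autodl-tmp/MING-MOE-4B"          # LoRA / MoLoRA adapter weights

tokenizer = AutoTokenizer.from_pretrained(base_path)
base = AutoModelForCausalLM.from_pretrained(base_path, torch_dtype="auto", device_map="auto")
model = PeftModel.from_pretrained(base, lora_path)  # attach the adapter to the base model
model = model.merge_and_unload()                    # fold the adapter weights into the base weights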

Initialization and message generation

The initialization part handles argument passing and weight merging.

generate_stream_gate passes the parameters to generate_stream in inference.py, which calls model.generate to produce the answer. Despite the name, what comes back here is not a token-by-token stream but one complete piece of text.

We then wrap this text in JSON and return it to the api_server as a stream or as JSON text.

    def generate_stream_gate(self, params):
        try:
            output = self.generate_stream_func(self.model, self.tokenizer, params, self.device,
                                               self.beam_size, self.context_len, args.stream_interval)
            ret = {
                "text": output,
                "error_code": 0,
            }
            yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.OutOfMemoryError as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
            }
            yield json.dumps(ret).encode() + b"\0"
        except (ValueError, RuntimeError) as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
            yield json.dumps(ret).encode() + b"\0"
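
Any consumer of this stream (the api_server, or a test client) then splits the response on the b"\0" separator and parses each chunk as JSON. A sketch, assuming the worker's default address and illustrative generation parameters:

import json
import requests

params = {"prompt": "Hello", "temperature": 0.7, "max_new_tokens": 256}   # illustrative fields
resp = requests.post("http://localhost:21002/worker_generate_stream",
                     json=params, stream=True)
for chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
    if not chunk:
        continue
    data = json.loads(chunk.decode())
    if data["error_code"] != 0:
        raise RuntimeError(data["text"])
    print(data["text"])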

Coroutines

The code introduces model_semaphore as an asyncio semaphore: a request must acquire it before generation starts and release it once the response has been sent, which caps the number of requests generating at the same time at args.limit_model_concurrency and keeps the coroutines synchronized.
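
A minimal standalone sketch of this pattern (the sleep stands in for the actual generate call; in the full worker below, the release happens in a FastAPI background task after the StreamingResponse has been sent):

import asyncio

sem = asyncio.Semaphore(5)                 # at most 5 generations run at the same time

async def guarded_generate(request_id):
    await sem.acquire()                    # queue up if the limit has been reached
    try:
        await asyncio.sleep(0.1)           # placeholder for the blocking model.generate call
        return f"generated text for request {request_id}"
    finally:
        sem.release()                      # free a slot as soon as this request is done

async def main():
    # 8 concurrent requests share the 5 slots
    print(await asyncio.gather(*(guarded_generate(i) for i in range(8))))

asyncio.run(main())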

Deployment

The worker is served with uvicorn and FastAPI.

Complete model_worker code

import argparse
import asyncio
import dataclasses
import logging
import json
import time

import threading
import uuid

from fastapi import FastAPI, Request, BackgroundTasks
from fastapi.responses import StreamingResponse, JSONResponse
import requests

try:
    from transformers import AutoTokenizer, AutoModelForCausalLM, LlamaTokenizer, AutoModel
except ImportError:
    from transformers import AutoTokenizer, AutoModelForCausalLM, LLaMATokenizer, AutoModel
import torch
import uvicorn

from ming.serve.inference import load_molora_pretrained_model, load_pretrained_model, generate_stream
from ming.conversations import conv_templates, get_default_conv_template, SeparatorStyle

from fastchat.constants import ErrorCode, SERVER_ERROR_MSG, WORKER_HEART_BEAT_INTERVAL

GB = 1 << 30

worker_id = str(uuid.uuid4())[:6]

global_counter = 0

model_semaphore = None
CONTROLLER_HEART_BEAT_EXPIRATION = 90   # seconds before the controller drops a silent worker
WORKER_HEART_BEAT_INTERVAL = 30         # seconds between heart beats (overrides the imported constant)


def pretty_print_semaphore(semaphore):
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"


def heart_beat_worker(controller):
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()


class ModelWorker:
    def __init__(self, controller_addr, worker_addr,
                 worker_id, no_register, model_path, model_name, model_base,
                 device, num_gpus, max_gpu_memory, load_8bit=False):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        self.beam_size = 1
        self.conv = conv_templates["qwen"].copy()
        if model_path.endswith("/"):
            model_path = model_path[:-1]
        self.model_name = model_name or model_path.split("/")[-1]
        self.device = device


        self.tokenizer, self.model, self.context_len, _ = load_molora_pretrained_model(model_path, model_base,
                                                                                       model_name, load_8bit,
                                                                                       None,
                                                                                       use_logit_bias=None,
                                                                                       only_load=None,
                                                                                       expert_selection=None)
        if hasattr(self.model.config, "max_sequence_length"):
            self.context_len = self.model.config.max_sequence_length
        elif hasattr(self.model.config, "max_position_embeddings"):
            self.context_len = self.model.config.max_position_embeddings
        else:
            self.context_len = 3072
        self.generate_stream_func = generate_stream

        if not no_register:
            self.register_to_controller()
            self.heart_beat_thread = threading.Thread(
                target=heart_beat_worker, args=(self,))
            self.heart_beat_thread.start()

    def register_to_controller(self):
        print("Register to controller")
        url = self.controller_addr + "/register_worker"
        data = {
            "worker_name": self.worker_addr,
            "check_heart_beat": True,
            "worker_status": self.get_status()
        }
        r = requests.post(url, json=data)
        assert r.status_code == 200

    def send_heart_beat(self):
        print(f"Send heart beat. Models: {[self.model_name]}. "
                    f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
                    f"global_counter: {global_counter}")

        url = self.controller_addr + "/receive_heart_beat"
        while True:
            try:
                ret = requests.post(url, json={
                    "worker_name": self.worker_addr,
                    "queue_length": self.get_queue_length()}, timeout=5)
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException as e:
                print(f"heart beat error: {e}")
            time.sleep(5)
        if not exist:
            self.register_to_controller()

    def get_queue_length(self):
        if model_semaphore is None or model_semaphore._value is None or model_semaphore._waiters is None:
            return 0
        else:
            return args.limit_model_concurrency - model_semaphore._value + len(
                model_semaphore._waiters)

    def get_status(self):
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

    def get_conv_template(self):
        return {"conv": self.conv}

    def generate_stream_gate(self, params):
        try:
            output = self.generate_stream_func(self.model, self.tokenizer, params, self.device,
                                               self.beam_size, self.context_len, args.stream_interval)
            ret = {
                "text": output,
                "error_code": 0,
            }
            yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.OutOfMemoryError as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.CUDA_OUT_OF_MEMORY,
            }
            yield json.dumps(ret).encode() + b"\0"
        except (ValueError, RuntimeError) as e:
            ret = {
                "text": f"{SERVER_ERROR_MSG}\n\n({e})",
                "error_code": ErrorCode.INTERNAL_ERROR,
            }
            yield json.dumps(ret).encode() + b"\0"

app = FastAPI()

def release_model_semaphore():
    model_semaphore.release()

@app.post("/worker_generate_stream")
async def api_generate_stream(request: Request):
    global model_semaphore, global_counter
    global_counter += 1
    params = await request.json()

    if model_semaphore is None:
        model_semaphore = asyncio.Semaphore(args.limit_model_concurrency)
    await model_semaphore.acquire()
    generator = worker.generate_stream_gate(params)
    background_tasks = BackgroundTasks()
    background_tasks.add_task(release_model_semaphore)
    return StreamingResponse(generator, background=background_tasks)


@app.post("/worker_get_status")
async def api_get_status(request: Request):
    return worker.get_status()


@app.post("/worker_get_conv_template")
async def api_get_conv(request: Request):
    return worker.get_conv_template()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=21002)
    parser.add_argument("--worker-address", type=str,
                        default="http://localhost:21002")
    parser.add_argument("--controller-address", type=str,
                        default="http://localhost:21001")
    parser.add_argument("--model-path", type=str, default="/root/autodl-tmp/MING-MOE-4B",
                        help="The path to the weights")
    parser.add_argument("--model-name", type=str,
                        help="Optional name")
    parser.add_argument("--model-base", type=str, default="/root/autodl-tmp/Qwen1.5-4B-Chat",
                        help="The base model")
    parser.add_argument("--device", type=str, choices=["cpu", "cuda", "mps"], default="cuda")
    parser.add_argument("--num-gpus", type=int, default=1)
    parser.add_argument("--max-gpu-memory", type=str, default="13GiB")
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--limit-model-concurrency", type=int, default=5)
    parser.add_argument("--stream-interval", type=int, default=2)
    parser.add_argument("--no-register", action="store_true")
    args = parser.parse_args()
    print(f"args: {args}")
    worker = ModelWorker(args.controller_address,
                         args.worker_address,
                         worker_id,
                         args.no_register,
                         args.model_path,
                         args.model_name,
                         args.model_base,
                         args.device,
                         args.num_gpus,
                         args.max_gpu_memory,
                         args.load_8bit)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
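
Once the worker is up (on http://localhost:21002 by default) and registered with the controller, a quick smoke test is to call the status endpoint defined above; with the default arguments the reply should look like the comment below:

import requests

status = requests.post("http://localhost:21002/worker_get_status").json()
print(status)   # e.g. {'model_names': ['MING-MOE-4B'], 'speed': 1, 'queue_length': 0}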
