JARVIS Source Code Analysis - awesome_chat.py, Part 3

1. In chat_huggingface, each task is executed in its own process that runs the run_task function:

 process = multiprocessing.Process(target=run_task, args=(input, task, d))   # this process updates the shared dict d
 process.start()
 processes.append(process)   # processes is the list of started worker processes
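
For run_task to be able to update d from a child process, d cannot be a plain dict; in awesome_chat.py it is a proxy dict created with multiprocessing.Manager(). A minimal, self-contained sketch of the same pattern (the task payloads and the worker body are invented for illustration):

import multiprocessing

def run_task(input, task, results):
    # each worker writes its result into the shared dict under its task id
    results[task["id"]] = f"done: {task['task']}"

if __name__ == "__main__":
    with multiprocessing.Manager() as manager:
        d = manager.dict()   # shared proxy dict, visible to all child processes
        tasks = [{"id": 0, "task": "image-classification"},
                 {"id": 1, "task": "object-detection"}]
        processes = []
        for task in tasks:
            process = multiprocessing.Process(target=run_task, args=("user input", task, d))
            process.start()
            processes.append(process)
        for process in processes:
            process.join()   # wait for every worker before reading d
        print(dict(d))       # e.g. {0: 'done: image-classification', 1: 'done: object-detection'}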

2. The run_task function executes one task: it first resolves the task's prerequisite (dependency) tasks, then selects a suitable model to run the task and returns the result.

def run_task(input, command, results):
    id = command["id"]
    args = command["args"]
    task = command["task"]
    deps = command["dep"]
    # collect the results of this task's prerequisite (dependency) tasks
    if deps[0] != -1:
        dep_tasks = [results[dep] for dep in deps]
    else:
        dep_tasks = []
   
    logger.debug(f"Run task: {id} - {task}")
    logger.debug("Deps: " + json.dumps(dep_tasks))

    # the task has prerequisites: resolve <GENERATED> references by resource type (image, audio, text)
    if deps[0] != -1:
        if "image" in args and "<GENERATED>-" in args["image"]:
            # HuggingGPT tags a resource generated by a prerequisite task as <GENERATED>-task_id, where task_id is the id of that prerequisite task; split out the task_id below
            resource_id = int(args["image"].split("-")[1])
            if "generated image" in results[resource_id]["inference result"]:
                args["image"] = results[resource_id]["inference result"]["generated image"]
        if "audio" in args and "<GENERATED>-" in args["audio"]:
            resource_id = int(args["audio"].split("-")[1])
            if "generated audio" in results[resource_id]["inference result"]:
                args["audio"] = results[resource_id]["inference result"]["generated audio"]
        if "text" in args and "<GENERATED>-" in args["text"]:
            resource_id = int(args["text"].split("-")[1])
            if "generated text" in results[resource_id]["inference result"]:
                args["text"] = results[resource_id]["inference result"]["generated text"]

    text = image = audio = None
    # pick up the text, image, and audio produced by the prerequisite tasks
    for dep_task in dep_tasks:
        # detect generated text in the prerequisite task's results
        if "generated text" in dep_task["inference result"]:
            # assign the prerequisite task's generated text to text
            text = dep_task["inference result"]["generated text"]
            logger.debug("Detect the generated text of dependency task (from results):" + text)
        # if nothing was generated, fall back to the text in the prerequisite task's args
        elif "text" in dep_task["task"]["args"]:
            text = dep_task["task"]["args"]["text"]
            logger.debug("Detect the text of dependency task (from args): " + text)
        # detect a generated image in the prerequisite task's results
        if "generated image" in dep_task["inference result"]:
            image = dep_task["inference result"]["generated image"]
            logger.debug("Detect the generated image of dependency task (from results): " + image)
        # if nothing was generated, fall back to the image in the prerequisite task's args
        elif "image" in dep_task["task"]["args"]:
            image = dep_task["task"]["args"]["image"]
            logger.debug("Detect the image of dependency task (from args): " + image)
        # detect generated audio in the prerequisite task's results
        if "generated audio" in dep_task["inference result"]:
            audio = dep_task["inference result"]["generated audio"]
            logger.debug("Detect the generated audio of dependency task (from results): " + audio)
        # if nothing was generated, fall back to the audio in the prerequisite task's args
        elif "audio" in dep_task["task"]["args"]:
            audio = dep_task["task"]["args"]["audio"]
            logger.debug("Detect the audio of dependency task (from args): " + audio)

    # write the collected image, audio, and text back into the corresponding task args
    if "image" in args and "<GENERATED>" in args["image"]:
        if image:
            args["image"] = image
    if "audio" in args and "<GENERATED>" in args["audio"]:
        if audio:
            args["audio"] = audio
    if "text" in args and "<GENERATED>" in args["text"]:
        if text:
            args["text"] = text

    for resource in ["image", "audio"]:
        if resource in args and not args[resource].startswith("public/") and len(args[resource]) > 0 and not args[resource].startswith("http"):
            args[resource] = f"public/{args[resource]}"
   
    # if the task name contains "-text-to-image" but args has no "text", switch to the corresponding control task instead
    if "-text-to-image" in command['task'] and "text" not in args:
        logger.debug("control-text-to-image task, but text is empty, so we use control-generation instead.")
        control = task.split("-")[0]
        # reclassify the task by its control type
        if control == "seg":
            task = "image-segmentation"     # image segmentation
            command['task'] = task
        elif control == "depth":
            task = "depth-estimation"       # depth estimation
            command['task'] = task
        else:                               # other control types
            task = f"{control}-control"

    command["args"] = args
    logger.debug(f"parsed task: {command}")

    # select a model by task type: tasks ending in "-text-to-image"
    if task.endswith("-text-to-image"):
        control = task.split("-")[0]
        best_model_id = f"lllyasviel/sd-controlnet-{control}"
        hosted_on = "local"
        reason = "ControlNet is the best model for this task."
        choose = {"id": best_model_id, "reason": reason}
        logger.debug(f"chosen model: {choose}")
    #以"-control"结尾的任务
    elif task.endswith("-control"):
        best_model_id = task
        hosted_on = "local"
        reason = "ControlNet tools"
        choose = {"id": best_model_id, "reason": reason}
        logger.debug(f"chosen model: {choose}")
    # all other tasks
    else:
        if task not in MODELS_MAP:
            logger.warning(f"no available models on {task} task.")
            record_case(success=False, **{"input": input, "task": command, "reason": f"task not support: {command['task']}", "op":"message"})
            inference_result = {"error": f"{command['task']} not found in available tasks."}
            results[id] = collect_result(command, choose, inference_result)
            return False
        # candidate models for this task
        candidates = MODELS_MAP[task][:10]
        all_avaliable_models = get_avaliable_models(candidates)
        # all available model ids, local and Hugging Face combined
        all_avaliable_model_ids = all_avaliable_models["local"] + all_avaliable_models["huggingface"]  
        logger.debug(f"avaliable models on {command['task']}: {all_avaliable_models}")

        # the list of available models is empty: log the case and return
        if len(all_avaliable_model_ids) == 0:
            logger.warning(f"no available models on {command['task']}")
            record_case(success=False, **{"input": input, "task": command, "reason": f"no available models: {command['task']}", "op":"message"})
            inference_result = {"error": f"no available models on {command['task']} task."}
            results[id] = colloct_result(command, "", inference_result)
            return False

        # exactly one model available
        if len(all_avaliable_model_ids) == 1:
            best_model_id = all_avaliable_model_ids[0]
            hosted_on = "unknown"
            reason = "Only one model available."
            choose = {"id": best_model_id, "reason": reason}
            logger.debug(f"chosen model: {choose}")

        # multiple models available: ask the LLM to choose one
        else:
            cand_models_info = [
                {
                    "id": model["id"],
                    "inference endpoint": all_avaliable_models.get(
                        "local" if model["id"] in all_avaliable_models["local"] else "huggingface"
                    ),
                    "likes": model.get("likes"),
                    "description": model.get("description", "")[:100],
                    "language": model.get("language"),
                    "tags": model.get("tags"),
                }
                for model in candidates
                if model["id"] in all_avaliable_model_ids
            ]

            # ask the LLM to choose a model from the candidates
            choose_str = choose_model(input, command, cand_models_info)
            logger.debug(f"chosen model: {choose_str}")
            try:
                choose = json.loads(choose_str)
                reason = choose["reason"]
                best_model_id = choose["id"]
                hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
            except Exception as e:
                logger.warning(f"the response [ {choose_str} ] is not a valid JSON, try to find the model id and reason in the response.")
                choose_str = find_json(choose_str)
                best_model_id, reason, choose  = get_id_reason(choose_str)
                hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
    inference_result = model_inference(best_model_id, args, hosted_on, command['task'])

    # if the inference result contains an error, record it in results and in the log, and return False
    if "error" in inference_result:
        logger.warning(f"Inference error: {inference_result['error']}")
        record_case(success=False, **{"input": input, "task": command, "reason": f"inference error: {inference_result['error']}", "op":"message"})
        results[id] = collect_result(command, choose, inference_result)
        return False
   
    # otherwise, record the result in results and return True
    results[id] = collect_result(command, choose, inference_result)
    return True
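
To make the dependency-resolution step concrete, here is the same <GENERATED>-task_id substitution run in isolation on a toy results dict (the task ids and file names are invented):

# prerequisite task 0 produced an image; task 1 consumes it via "<GENERATED>-0"
results = {0: {"inference result": {"generated image": "public/images/a1b2.png"}}}
args = {"image": "<GENERATED>-0", "text": "a red car"}

if "image" in args and "<GENERATED>-" in args["image"]:
    resource_id = int(args["image"].split("-")[1])
    if "generated image" in results[resource_id]["inference result"]:
        args["image"] = results[resource_id]["inference result"]["generated image"]

print(args)   # {'image': 'public/images/a1b2.png', 'text': 'a red car'}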

3. run_task calls record_case, which appends the case to a log file: successes go to log_success.jsonl, failures to log_fail.jsonl.

def record_case(success, **args):
    if success:
        f = open("log_success.jsonl", "a")
    else:
        f = open("log_fail.jsonl", "a")
    log = args
    f.write(json.dumps(log) + "\n")
    f.close()
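
Each call appends exactly one JSON object per line (JSON Lines), so the logs can later be filtered or replayed line by line. An illustrative failure record (the field values are made up):

record_case(success=False,
            **{"input": "draw a cat", "task": {"id": 0, "task": "text-to-image"},
               "reason": "no available models: text-to-image", "op": "message"})
# appends one line to log_fail.jsonl:
# {"input": "draw a cat", "task": {"id": 0, "task": "text-to-image"}, "reason": "no available models: text-to-image", "op": "message"}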

4. run_task calls collect_result to assemble the result record for a task.

def collect_result(command, choose, inference_result):
    result = {"task": command}
    result["inference result"] = inference_result
    result["choose model result"] = choose
    logger.debug(f"inference result: {inference_result}")
    return result
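
The assembled record for one task then has three keys; for example (all values illustrative):

result = collect_result(
    {"id": 0, "task": "text-to-image", "args": {"text": "a red car"}, "dep": [-1]},
    {"id": "runwayml/stable-diffusion-v1-5", "reason": "most likes"},
    {"generated image": "public/images/a1b2.png"},
)
print(result["task"]["id"], result["choose model result"]["id"])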

5. run_task calls get_avaliable_models, which determines which candidate models are available locally or remotely.

def get_avaliable_models(candidates, topk=5):
    all_available_models = {"local": [], "huggingface": []}
    processes = []
    result_queue = multiprocessing.Queue()  # create a queue for collecting status results

    # iterate over all candidate models
    for candidate in candidates:
        model_id = candidate["id"]

        # check the remote (Hugging Face) endpoint
        if inference_mode != "local":
            huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}"
            process = multiprocessing.Process(target=get_model_status, args=(model_id, huggingfaceStatusUrl, HUGGINGFACE_HEADERS, result_queue))
            processes.append(process)
            process.start()
       
        # check the local endpoint
        if inference_mode != "huggingface":
            localStatusUrl = f"{Model_Server}/status/{model_id}"
            process = multiprocessing.Process(target=get_model_status, args=(model_id, localStatusUrl, {}, result_queue))
            processes.append(process)
            process.start()
       
    result_count = len(processes)
    while result_count:
        model_id, status, endpoint_type = result_queue.get()
        # note: "model_id not in all_available_models" tests the dict keys ("local"/"huggingface"), so it is always true for a model id
        if status and model_id not in all_available_models:
            all_available_models[endpoint_type].append(model_id)
        # stop early once topk models have been found
        if len(all_available_models["local"] + all_available_models["huggingface"]) >= topk:
            break
        result_count -= 1

    # wait for all child processes to finish before the main process continues
    for process in processes:
        process.join()

    return all_available_models
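
A usage sketch, assuming the module-level globals (inference_mode, Model_Server, HUGGINGFACE_HEADERS) are configured as in awesome_chat.py; the candidate ids below are placeholders, whereas in run_task they come from MODELS_MAP[task]:

if __name__ == "__main__":
    candidates = [{"id": "facebook/detr-resnet-50"},
                  {"id": "google/owlvit-base-patch32"}]
    available = get_avaliable_models(candidates, topk=5)
    print(available)   # e.g. {'local': [], 'huggingface': ['facebook/detr-resnet-50']}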

6. get_avaliable_models calls get_model_status, which checks the status of a model on the local server or on the remote Hugging Face service.

def get_model_status(model_id, url, headers, queue):
    endpoint_type = "huggingface" if "huggingface" in url else "local"
    if "huggingface" in url:  #请求远程服务
        r = requests.get(url, headers=headers, proxies=PROXY)
    else:   #本地
        r = requests.get(url)
    #请求返回状态码是200的,加入队列
    if r.status_code == 200 and "loaded" in r.json() and r.json()["loaded"]:
        queue.put((model_id, True, endpoint_type))
    else:
        queue.put((model_id, False, None))
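
The Hugging Face status endpoint returns a JSON document whose loaded field indicates whether the model is warm on the Inference API. A standalone check of the same endpoint (the exact response fields beyond loaded may vary):

import requests

url = "https://api-inference.huggingface.co/status/runwayml/stable-diffusion-v1-5"
r = requests.get(url)
print(r.status_code, r.json())   # e.g. 200 {'loaded': True, ...}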

7. run_task calls get_id_reason, which extracts the id and reason fields for the chosen model.

def get_id_reason(choose_str):
    reason = field_extract(choose_str, "reason")
    id = field_extract(choose_str, "id")
    choose = {"id": id, "reason": reason}
    return id.strip(), reason.strip(), choose

    get_id_reason calls the field_extract function, which uses regular expressions to match the fields in the string.

def field_extract(s, field):
    # match the field with a regex; fall back to an alternate pattern on failure
    try:
        field_rep = re.compile(f'{field}.*?:.*?"(.*?)"', re.IGNORECASE)
        extracted = field_rep.search(s).group(1).replace("\"", "\'")
    except:
        field_rep = re.compile(f'{field}:\ *"(.*?)"', re.IGNORECASE)
        extracted = field_rep.search(s).group(1).replace("\"", "\'")
    return extracted
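
For example, when the LLM response is not valid JSON, field_extract can still recover the fields by pattern matching (the response string below is invented):

s = 'I would pick {"id": "runwayml/stable-diffusion-v1-5", "reason": "it has the most likes"'
print(field_extract(s, "id"))      # runwayml/stable-diffusion-v1-5
print(field_extract(s, "reason"))  # it has the most likes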

To be continued.
