1. The chat_huggingface function starts a separate process that runs run_task for each task:
process = multiprocessing.Process(target=run_task, args=(input, task, d))  # this process updates the shared dict d
process.start()
processes.append(process)  # processes is the list of started worker processes
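The shared dict d has to come from a multiprocessing.Manager(); a plain dict would not propagate writes made in a child process back to the parent. A minimal standalone sketch of the pattern (worker stands in for run_task, and all names and values here are illustrative):

import multiprocessing

def worker(task_id, results):
    # a child process writes its result into the shared dict, just like run_task does with results[id]
    results[task_id] = {"inference result": {"generated text": f"output of task {task_id}"}}

if __name__ == "__main__":
    d = multiprocessing.Manager().dict()  # shared between the parent and all child processes
    processes = []
    for task_id in range(3):
        p = multiprocessing.Process(target=worker, args=(task_id, d))
        p.start()
        processes.append(p)
    for p in processes:  # wait for every task to finish before reading the results
        p.join()
    print(dict(d))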
2. run_task executes one task: it first resolves the task's prerequisite (dependency) tasks, then picks a suitable model, runs inference, and stores the result:
def run_task(input, command, results):
    id = command["id"]
    args = command["args"]
    task = command["task"]
    deps = command["dep"]
    # the task's list of prerequisite (dependency) tasks
    if deps[0] != -1:
        dep_tasks = [results[dep] for dep in deps]
    else:
        dep_tasks = []
    logger.debug(f"Run task: {id} - {task}")
    logger.debug("Deps: " + json.dumps(dep_tasks))
    # the task has prerequisites: resolve the generated image, audio, and text resources
    if deps[0] != -1:
        if "image" in args and "<GENERATED>-" in args["image"]:
            # HuggingGPT marks a resource produced by a prerequisite task as <GENERATED>-task_id,
            # where task_id is the id of that prerequisite task; split the task_id out here
            resource_id = int(args["image"].split("-")[1])
            if "generated image" in results[resource_id]["inference result"]:
                args["image"] = results[resource_id]["inference result"]["generated image"]
        if "audio" in args and "<GENERATED>-" in args["audio"]:
            resource_id = int(args["audio"].split("-")[1])
            if "generated audio" in results[resource_id]["inference result"]:
                args["audio"] = results[resource_id]["inference result"]["generated audio"]
        if "text" in args and "<GENERATED>-" in args["text"]:
            resource_id = int(args["text"].split("-")[1])
            if "generated text" in results[resource_id]["inference result"]:
                args["text"] = results[resource_id]["inference result"]["generated text"]
    text = image = audio = None
    # collect the text, image, and audio produced by the prerequisite tasks
    for dep_task in dep_tasks:
        # look for generated text in the prerequisite task's results
        if "generated text" in dep_task["inference result"]:
            # assign the prerequisite task's generated text to text
            text = dep_task["inference result"]["generated text"]
            logger.debug("Detect the generated text of dependency task (from results):" + text)
        # otherwise, look for text in the prerequisite task's args
        elif "text" in dep_task["task"]["args"]:
            text = dep_task["task"]["args"]["text"]
            logger.debug("Detect the text of dependency task (from args): " + text)
        # look for a generated image in the prerequisite task's results
        if "generated image" in dep_task["inference result"]:
            image = dep_task["inference result"]["generated image"]
            logger.debug("Detect the generated image of dependency task (from results): " + image)
        # otherwise, look for an image in the prerequisite task's args
        elif "image" in dep_task["task"]["args"]:
            image = dep_task["task"]["args"]["image"]
            logger.debug("Detect the image of dependency task (from args): " + image)
        # look for generated audio in the prerequisite task's results
        if "generated audio" in dep_task["inference result"]:
            audio = dep_task["inference result"]["generated audio"]
            logger.debug("Detect the generated audio of dependency task (from results): " + audio)
        # otherwise, look for audio in the prerequisite task's args
        elif "audio" in dep_task["task"]["args"]:
            audio = dep_task["task"]["args"]["audio"]
            logger.debug("Detect the audio of dependency task (from args): " + audio)
    # write the collected image, audio, and text back into the corresponding task args
    if "image" in args and "<GENERATED>" in args["image"]:
        if image:
            args["image"] = image
    if "audio" in args and "<GENERATED>" in args["audio"]:
        if audio:
            args["audio"] = audio
    if "text" in args and "<GENERATED>" in args["text"]:
        if text:
            args["text"] = text
    for resource in ["image", "audio"]:
        if resource in args and not args[resource].startswith("public/") and len(args[resource]) > 0 and not args[resource].startswith("http"):
            args[resource] = f"public/{args[resource]}"
    # if the task name contains "-text-to-image" but args has no "text", switch to the corresponding control task
    if "-text-to-image" in command['task'] and "text" not in args:
        logger.debug("control-text-to-image task, but text is empty, so we use control-generation instead.")
        control = task.split("-")[0]
        # map the control type to a concrete task
        if control == "seg":
            task = "image-segmentation"  # image segmentation
            command['task'] = task
        elif control == "depth":
            task = "depth-estimation"  # depth estimation
            command['task'] = task
        else:  # other control types
            task = f"{control}-control"
    command["args"] = args
    logger.debug(f"parsed task: {command}")
    # pick a model according to the task type: tasks ending with "-text-to-image"
    if task.endswith("-text-to-image"):
        control = task.split("-")[0]
        best_model_id = f"lllyasviel/sd-controlnet-{control}"
        hosted_on = "local"
        reason = "ControlNet is the best model for this task."
        choose = {"id": best_model_id, "reason": reason}
        logger.debug(f"chosen model: {choose}")
    # tasks ending with "-control"
    elif task.endswith("-control"):
        best_model_id = task
        hosted_on = "local"
        reason = "ControlNet tools"
        choose = {"id": best_model_id, "reason": reason}
        logger.debug(f"chosen model: {choose}")
    # all other tasks
    else:
        if task not in MODELS_MAP:
            logger.warning(f"no available models on {task} task.")
            record_case(success=False, **{"input": input, "task": command, "reason": f"task not support: {command['task']}", "op":"message"})
            inference_result = {"error": f"{command['task']} not found in available tasks."}
            results[id] = colloct_result(command, choose, inference_result)
            return False
        # candidate models
        candidates = MODELS_MAP[task][:10]
        all_avaliable_models = get_avaliable_models(candidates)
        # all available model ids, local and Hugging Face combined
        all_avaliable_model_ids = all_avaliable_models["local"] + all_avaliable_models["huggingface"]
        logger.debug(f"avaliable models on {command['task']}: {all_avaliable_models}")
        # no available models: record the case, store the error result, and return
        if len(all_avaliable_model_ids) == 0:
            logger.warning(f"no available models on {command['task']}")
            record_case(success=False, **{"input": input, "task": command, "reason": f"no available models: {command['task']}", "op":"message"})
            inference_result = {"error": f"no available models on {command['task']} task."}
            results[id] = colloct_result(command, "", inference_result)
            return False
        # exactly one available model
        if len(all_avaliable_model_ids) == 1:
            best_model_id = all_avaliable_model_ids[0]
            hosted_on = "unknown"
            reason = "Only one model available."
            choose = {"id": best_model_id, "reason": reason}
            logger.debug(f"chosen model: {choose}")
        # more than one available model: let the LLM choose
        else:
            cand_models_info = [
                {
                    "id": model["id"],
                    "inference endpoint": all_avaliable_models.get(
                        "local" if model["id"] in all_avaliable_models["local"] else "huggingface"
                    ),
                    "likes": model.get("likes"),
                    "description": model.get("description", "")[:100],
                    "language": model.get("language"),
                    "tags": model.get("tags"),
                }
                for model in candidates
                if model["id"] in all_avaliable_model_ids
            ]
            # ask the LLM to choose a model from the candidates
            choose_str = choose_model(input, command, cand_models_info)
            logger.debug(f"chosen model: {choose_str}")
            try:
                choose = json.loads(choose_str)
                reason = choose["reason"]
                best_model_id = choose["id"]
                hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
            except Exception as e:
                logger.warning(f"the response [ {choose_str} ] is not a valid JSON, try to find the model id and reason in the response.")
                choose_str = find_json(choose_str)
                best_model_id, reason, choose = get_id_reason(choose_str)
                hosted_on = "local" if best_model_id in all_avaliable_models["local"] else "huggingface"
    inference_result = model_inference(best_model_id, args, hosted_on, command['task'])
    # if the inference result contains an error, record the case, store the result, and return False
    if "error" in inference_result:
        logger.warning(f"Inference error: {inference_result['error']}")
        record_case(success=False, **{"input": input, "task": command, "reason": f"inference error: {inference_result['error']}", "op":"message"})
        results[id] = colloct_result(command, choose, inference_result)
        return False
    # otherwise store the result and return True
    results[id] = colloct_result(command, choose, inference_result)
    return True
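To make the dependency handling concrete, here is a hypothetical pair of planned tasks in the format run_task expects; the field values are illustrative, not taken from the source:

# task 1 depends on task 0: its "text" argument refers to task 0's output via <GENERATED>-0
task_0 = {"id": 0, "task": "image-to-text", "dep": [-1], "args": {"image": "example.jpg"}}
task_1 = {"id": 1, "task": "text-to-speech", "dep": [0], "args": {"text": "<GENERATED>-0"}}
# once task 0 has finished, results[0] holds the structure built by colloct_result:
# results[0] = {"task": task_0,
#               "inference result": {"generated text": "a cat sitting on a sofa"},
#               "choose model result": {"id": "...", "reason": "..."}}
# run_task(input, task_1, results) then replaces "<GENERATED>-0" in task_1["args"]["text"]
# with results[0]["inference result"]["generated text"] before choosing a model and running inference.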
3. run_task calls record_case, which appends a record to a log file: successful cases go to log_success.jsonl, failed cases to log_fail.jsonl.
def record_case(success, **args):
    if success:
        f = open("log_success.jsonl", "a")
    else:
        f = open("log_fail.jsonl", "a")
    log = args
    f.write(json.dumps(log) + "\n")
    f.close()
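For example, a failed case from run_task ends up as one JSON object per line in log_fail.jsonl (the values below are illustrative):

record_case(success=False, **{"input": "draw a cat", "task": {"id": 0, "task": "unknown-task"},
                              "reason": "task not support: unknown-task", "op": "message"})
# appends to log_fail.jsonl:
# {"input": "draw a cat", "task": {"id": 0, "task": "unknown-task"}, "reason": "task not support: unknown-task", "op": "message"}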
4. run_task calls colloct_result to assemble the result for a task.
def colloct_result(command, choose, inference_result):
    result = {"task": command}
    result["inference result"] = inference_result
    result["choose model result"] = choose
    logger.debug(f"inference result: {inference_result}")
    return result
5. run_task calls get_avaliable_models to find which candidate models are currently available, either on the local model server or on the Hugging Face Inference API.
def get_avaliable_models(candidates, topk=5):
    all_available_models = {"local": [], "huggingface": []}
    processes = []
    result_queue = multiprocessing.Queue()  # queue that collects the status of each model
    # check every candidate model
    for candidate in candidates:
        model_id = candidate["id"]
        # check the remote Hugging Face endpoint
        if inference_mode != "local":
            huggingfaceStatusUrl = f"https://api-inference.huggingface.co/status/{model_id}"
            process = multiprocessing.Process(target=get_model_status, args=(model_id, huggingfaceStatusUrl, HUGGINGFACE_HEADERS, result_queue))
            processes.append(process)
            process.start()
        # check the local model server
        if inference_mode != "huggingface":
            localStatusUrl = f"{Model_Server}/status/{model_id}"
            process = multiprocessing.Process(target=get_model_status, args=(model_id, localStatusUrl, {}, result_queue))
            processes.append(process)
            process.start()
    result_count = len(processes)
    while result_count:
        model_id, status, endpoint_type = result_queue.get()
        if status and model_id not in all_available_models:
            all_available_models[endpoint_type].append(model_id)
        if len(all_available_models["local"] + all_available_models["huggingface"]) >= topk:
            break
        result_count -= 1
    # wait for all child processes to finish before returning to the main process
    for process in processes:
        process.join()
    return all_available_models
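A hypothetical call; the candidate entries are assumed to be model records with at least an "id" field, matching how MODELS_MAP is used above (the likes, tags, and availability shown are illustrative):

candidates = [{"id": "nlpconnect/vit-gpt2-image-captioning", "likes": 500, "tags": ["image-to-text"]},
              {"id": "Salesforce/blip-image-captioning-base", "likes": 300, "tags": ["image-to-text"]}]
available = get_avaliable_models(candidates, topk=5)
# e.g. {"local": [], "huggingface": ["nlpconnect/vit-gpt2-image-captioning", "Salesforce/blip-image-captioning-base"]}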
6. get_avaliable_models calls get_model_status to query the status of a model, either on the local model server or on the remote Hugging Face service.
def get_model_status(model_id, url, headers, queue):
    endpoint_type = "huggingface" if "huggingface" in url else "local"
    if "huggingface" in url:  # query the remote Hugging Face service
        r = requests.get(url, headers=headers, proxies=PROXY)
    else:  # query the local model server
        r = requests.get(url)
    # a 200 response whose JSON body reports the model as loaded means the model is available
    if r.status_code == 200 and "loaded" in r.json() and r.json()["loaded"]:
        queue.put((model_id, True, endpoint_type))
    else:
        queue.put((model_id, False, None))
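For reference, the function only marks a model as available when the status endpoint answers with HTTP 200 and a JSON body containing "loaded": true. A rough usage sketch (the model id is illustrative; HUGGINGFACE_HEADERS is the header dict used above):

q = multiprocessing.Queue()
get_model_status("runwayml/stable-diffusion-v1-5",
                 "https://api-inference.huggingface.co/status/runwayml/stable-diffusion-v1-5",
                 HUGGINGFACE_HEADERS, q)
print(q.get())  # ("runwayml/stable-diffusion-v1-5", True, "huggingface") if the model is loaded, otherwise (..., False, None)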
7. run_task calls get_id_reason to extract the chosen model's id and reason fields from the model-selection response.
def get_id_reason(choose_str):
    reason = field_extract(choose_str, "reason")
    id = field_extract(choose_str, "id")
    choose = {"id": id, "reason": reason}
    return id.strip(), reason.strip(), choose
get_id_reason calls field_extract, which uses regular expressions to pull a field out of the string:
def field_extract(s, field):
    # match the field with a regular expression
    try:
        field_rep = re.compile(f'{field}.*?:.*?"(.*?)"', re.IGNORECASE)
        extracted = field_rep.search(s).group(1).replace("\"", "\'")
    except:
        field_rep = re.compile(f'{field}:\ *"(.*?)"', re.IGNORECASE)
        extracted = field_rep.search(s).group(1).replace("\"", "\'")
    return extracted
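For example, when the LLM reply is not valid JSON, get_id_reason can still recover the two fields (the reply string below is illustrative):

choose_str = 'Sure! "id": "runwayml/stable-diffusion-v1-5", "reason": "it has the most likes"'
id, reason, choose = get_id_reason(choose_str)
# id     -> "runwayml/stable-diffusion-v1-5"
# reason -> "it has the most likes"
# choose -> {"id": "runwayml/stable-diffusion-v1-5", "reason": "it has the most likes"}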
To be continued.