背景:数据量很大的时候,需要多线程调用接口获取数据但是又不想一次性加载全部的原始数据进列表,可以结合批量加载数据和多线程。
代码实现:
批量加载:
def load_data_in_batches(file_path, batch_size):
    """Yield lists of at most ``batch_size`` records parsed from *file_path*.

    A ``.json`` file must contain a single JSON array; any other extension
    is treated as JSON-lines (one JSON value per line, blank lines skipped).

    Args:
        file_path (str): path to a ``.json`` array file or a JSON-lines file.
        batch_size (int): maximum number of records per yielded batch.

    Yields:
        list: consecutive slices of the parsed records.
    """
    if file_path.endswith(".json"):  # whole-array JSON file
        with open(file_path, "r", encoding="utf-8") as file:
            data_list = json.load(file)
    else:
        data_list = []
        with open(file_path, "r", encoding="utf-8") as f:
            # Iterate the file line by line instead of reading it all into one
            # string; skip blank lines so a trailing newline (or empty line)
            # does not crash json.loads.
            for line in f:
                line = line.strip()
                if line:
                    data_list.append(json.loads(line))
    for i in range(0, len(data_list), batch_size):
        yield data_list[i:i + batch_size]
def prepare_batch_input_data(batch, model, temperature=None):
    """Shape raw records into request payloads for the inference endpoint.

    Only needed when the raw data must be transformed before being sent to
    the API; if the records are already in request shape, skip this step.

    Args:
        batch: records of the form ``{"messages": [{"content": ...}, ...]}``.
        model: model identifier to embed in every payload.
        temperature: optional sampling temperature; omitted from the payload
            when ``None``.

    Returns:
        list[dict]: one request payload per input record.
    """
    # Build the optional fields once; they are identical for every record.
    extra = {} if temperature is None else {"temperature": temperature}
    return [
        {
            "prompt": [{"role": "user", "content": record["messages"][0]["content"]}],
            "label": "",
            "model": model,
            **extra,
        }
        for record in batch
    ]
多线程调用接口:
import concurrent.futures
import copy
import json
import logging
import time

import requests
from tqdm import tqdm
default_logger = logging.getLogger(__name__)
class SSEClient:
    """Sends one streaming (SSE) inference request and aggregates the reply.

    Collects ``data:`` fields from the event stream into a single response
    string and records request/response timestamps for latency reporting.
    """

    def __init__(
        self,
        data,
        url="",
        auth_token="",
        logger=None,
    ):
        """Prepare a client for a single request.

        Args:
            data: JSON-serializable request body.
            url: endpoint URL. (In the original this parameter was commented
                out while ``self.url = url`` remained, raising NameError.)
            auth_token: value for the ``Authorization`` header.
            logger: optional logger; defaults to the module logger. The
                original only assigned ``self.logger`` in the None branch,
                leaving it unset when a logger was supplied.
        """
        self.logger = logger if logger is not None else default_logger
        self.url = url
        self.headers = {
            "Authorization": auth_token,
            "Content-Type": "application/json",
        }
        self.data = data
        self.send_time = 0  # when the request was sent
        self.create_time = time.time()  # when this client object was created
        self.first_message_time = 0  # when the first streamed message arrived
        self.first_server_response = 0  # reserved: first server response time
        # The original referenced self.session_id without ever assigning it,
        # so update_logger_context raised AttributeError. Derive a unique id
        # from the object identity and creation time.
        self.session_id = f"{id(self):x}-{int(self.create_time * 1000)}"
        self.update_logger_context()

    def update_logger_context(self):
        """Attach session/task identifiers to every log record."""
        self.logger = logging.LoggerAdapter(
            self.logger, {"session_id": self.session_id, "task_id": self.session_id}
        )

    def get_result(self, show_details=False):
        """POST ``self.data`` and consume the SSE stream.

        Args:
            show_details: when True, print response headers and each parsed
                event for debugging.

        Returns:
            dict: on success, the aggregated response text plus timing fields
            and ``"status": "success"``; when the server answers with plain
            JSON (an error), ``{"status": "failed", "message": ...}``.
        """
        self.send_time = time.time()  # record when the request goes out
        response = requests.post(
            self.url,
            headers=self.headers,
            json=self.data,
            stream=True,
        )
        # A JSON content type means the server rejected the request outright
        # instead of opening an event stream.
        if response.headers.get("Content-Type") == "application/json":
            return {"status": "failed", "message": response.json()}
        if show_details:
            print("headers is ", response.headers)
        text = ""
        event_type = ""
        buffer = ""
        final_text = ""
        for line in response.iter_lines(decode_unicode=True):
            buffer += line + "\n"
            if not self.first_message_time:  # first streamed message seen
                self.first_message_time = time.time()
                self.logger.info(
                    f"First server message response at {self.first_message_time}"
                )
            # In SSE, a blank line (=> buffer ends with "\n\n") terminates
            # one event; parse its field lines then reset the buffer.
            if buffer.endswith("\n\n"):
                for field_line in buffer.split("\n"):
                    if field_line.startswith("event:"):
                        event_type = field_line[len("event:"):].strip()
                        if event_type == "finish":
                            break
                    elif field_line.startswith("data:"):
                        # The stream escapes newlines as literal "\n".
                        value = field_line[len("data:"):].replace("\\n", "\n")
                        text += value
                if show_details:
                    print(f"Event Type: {event_type}, Text: {text}")
                final_text += text
                buffer = ""
                text = ""
        self.message = final_text
        result = {
            "response": self.message,
            "data": copy.deepcopy(self.data),
            "send_time": self.send_time,
            "create_time": self.create_time,
            "first_message_time": self.first_message_time,
            "get_first_message_time": round(
                self.first_message_time - self.send_time, 2
            ),
            "duration": round(time.time() - self.send_time, 2),
            "status": "success",
        }
        self.logger.info(f"Final result is {json.dumps(result)}")
        return result
def invoke_model(data, save_path=None, n_calls=1):
    """Call the SSE endpoint for one input record, optionally persisting it.

    Args:
        data: request payload; may carry a per-record ``"url"`` key that
            overrides the client's default endpoint.
        save_path: when given, each result is appended to this file as one
            JSON line. (The original wrote to an undefined global
            ``save_path`` and passed the raw dict to ``f.write``, which
            raises TypeError; results are serialized with json.dumps here.)
        n_calls: number of times to call the endpoint for this record
            (useful when one input should be sampled several times).

    Returns:
        dict | None: the result of the last call, or None if ``n_calls`` is 0.
    """
    result = None
    for _ in range(n_calls):
        if "url" in data:
            result = SSEClient(data, url=data["url"]).get_result()
        else:
            result = SSEClient(data).get_result()
        # Post-processing / storage hook: append one JSON line per result.
        if save_path is not None:
            with open(save_path, "a", encoding="utf-8") as f:
                f.write(json.dumps(result, ensure_ascii=False))
                f.write("\n")
    return result
def parallel_execution(demo_json_list, n_jobs=4):
    """Fan ``demo_json_list`` out to ``invoke_model`` on a thread pool.

    Args:
        demo_json_list: request payloads, one per call.
        n_jobs: upper bound on worker threads (capped at the list length).

    Returns:
        list: results aligned with ``demo_json_list``; an entry is None when
        the corresponding call raised or exceeded the 120 s timeout. (The
        original built this list but never returned it, so callers always
        received None.)
    """
    if not demo_json_list:  # avoid ThreadPoolExecutor(max_workers=0)
        return []
    n_jobs = min(n_jobs, len(demo_json_list))
    results = [None] * len(demo_json_list)
    print(len(demo_json_list))
    with concurrent.futures.ThreadPoolExecutor(max_workers=n_jobs) as executor:
        # Map each future to its index so results keep input order.
        future_to_idx = {
            executor.submit(invoke_model, demo): idx
            for idx, demo in enumerate(demo_json_list)
        }
        for future, idx in tqdm(future_to_idx.items(), desc="Processing"):
            try:
                results[idx] = future.result(timeout=120)
            except Exception as e:
                # Keep going on per-item failure; the slot stays None.
                print(e)
                results[idx] = None
    return results
联合使用:
def main(test_data_list, batch_size, model="", n_jobs=4):
    """Run every configured data file through the batched, threaded pipeline.

    Args:
        test_data_list: configs of the form
            ``{"file_path": ..., "temperature": ...}`` (temperature optional).
        batch_size: number of records loaded and dispatched per batch.
        model: model identifier forwarded to every request payload. (The
            original read an undefined global ``model``.)
        n_jobs: thread-pool size per batch. (Also an undefined global in the
            original.)
    """
    for test_data in test_data_list:  # support multiple input files
        file_path = test_data["file_path"]
        temperature = test_data.get("temperature")
        # Load and process data in batches so the full file never sits in
        # memory alongside all in-flight requests at once.
        for batch in load_data_in_batches(file_path, batch_size):
            input_data_list = prepare_batch_input_data(batch, model, temperature)
            response_list = parallel_execution(input_data_list, n_jobs=n_jobs)
if __name__ == "__main__":
    # The original opened this literal with "{" against a list body — a
    # syntax error; the intended value is a list of per-file configs.
    data_list = [
        {"file_path": "", "temperature": 0.1},
        {"file_path": "", "temperature": 0.2},
    ]
    main(data_list, 10)