【Datawhale AI 夏令营第三期 Task 1 学习笔记】
Task 1 任务安排:
1. 跑通给的BaseLine,然后提交结果文件,获取第一个分数。
2. 听助教直播,了解赛事要求以及一些细节。
3. 听助教讲解BaseLine代码,尝试理解代码逻辑。
前两点比较简单,没啥好说的。这里就着重说一下第三点。
着重讲讲baseline中重点方法函数。
def process_datas(datas,MODEL_NAME):
results = []
with ThreadPoolExecutor(max_workers=16) as executor:
future_data = {}
lasttask = ''
lastmark = 0
lens = 0
for data in tqdm(datas, desc="Submitting tasks", total=len(datas)):
problem = data['problem']
for id,question in enumerate(data['questions']):
prompt = get_prompt(problem,
question['question'],
question['options'],
)
future = executor.submit(api_retry, MODEL_NAME, prompt)
future_data[future] = (data,id)
time.sleep(0.6) # 控制每0.5秒提交一个任务
lens += 1
for future in tqdm(as_completed(future_data), total=lens, desc="Processing tasks"):
# print('data',data)
data = future_data[future][0]
problem_id = future_data[future][1]
try:
res = future.result()
extract_response = extract(res)
# print('res',extract_response)
data['questions'][problem_id]['answer'] = extract_response
results.append(data)
# print('data',data)
except Exception as e:
logger.error(f"Failed to process text: {data}. Error: {e}")
return results
这段代码定义了一个名为 process_datas
的函数,该函数用于处理一系列数据,并利用多线程执行异步任务。下面是代码的详细解释:
函数签名
def process_datas(datas, MODEL_NAME):
-
参数:
datas
: 一个列表,其中每个元素都是一个字典,包含问题和相关问题的选项等信息。MODEL_NAME
: 一个字符串,表示要使用的模型名称。
-
返回值:
- 一个列表,包含了处理后的数据。
函数内部逻辑
-
初始化:
- 创建一个空列表
results
用于存储处理后的数据。 - 创建一个
ThreadPoolExecutor
实例,指定最大工作线程数为 16。 - 初始化一个字典
future_data
用于存储提交的任务和对应的原始数据。 - 初始化变量
lasttask
和lastmark
,它们在这里没有被使用。 - 初始化变量
lens
用于记录提交的任务总数。
- 创建一个空列表
-
提交任务:
- 使用
tqdm
进度条显示提交任务的过程。 - 遍历输入列表
datas
中的每个数据项。 - 对于每个数据项中的每个问题
question
:- 构造一个提示(prompt)字符串,该字符串包含了问题和选项。
- 使用
executor.submit
提交一个任务,该任务调用api_retry
函数,并将MODEL_NAME
和prompt
作为参数传递。 - 将任务的
Future
对象和对应的数据项存储在future_data
字典中。 - 每提交一个任务后,暂停 0.6 秒以控制任务提交的速度。
- 增加任务计数器
lens
。
- 使用
-
处理完成的任务:
- 使用
as_completed
函数获取已完成的任务,并使用tqdm
显示进度。 - 遍历已完成的任务,并从中获取任务的原始数据项。
- 尝试获取任务的结果:
- 调用
future.result()
获取任务的返回值。 - 调用
extract
函数处理返回值,并将结果存储回原始数据项的相应字段。 - 将处理后的数据项追加到结果列表
results
中。
- 调用
- 如果处理过程中出现异常,记录错误信息。
- 使用
-
返回结果:
- 最终返回包含处理后数据的列表
results
。
- 最终返回包含处理后数据的列表
代码分析
-
多线程处理:
- 使用
ThreadPoolExecutor
来并发执行任务,这提高了处理效率。 - 每个任务独立提交,通过
submit
方法异步执行。
- 使用
-
错误处理:
- 通过
try-except
结构捕获并记录处理过程中可能出现的异常。
- 通过
-
进度跟踪:
- 使用
tqdm
库显示任务提交和处理的进度条,方便监控进程。
- 使用
注意事项
-
依赖项:
- 这段代码依赖于
tqdm
和concurrent.futures
库,确保这些库已安装。 api_retry
和get_prompt
函数需要在外部定义。logger
应该是预先配置的日志记录器。
- 这段代码依赖于
-
extract
函数:- 该函数用于从 API 返回的结果中提取答案。需要确保
extract
函数已经被定义。
- 该函数用于从 API 返回的结果中提取答案。需要确保
-
任务调度:
- 代码中每提交一个任务后暂停 0.6 秒,这有助于避免对 API 的频繁请求。
示例使用
import concurrent.futures
from tqdm import tqdm
import time
import logging
# 假设的 api_retry 函数
def api_retry(model_name, prompt):
# 模拟 API 调用
time.sleep(1) # 模拟延迟
return "答案是:C"
# 假设的 get_prompt 函数
def get_prompt(problem, question, options):
# 构造提示字符串
return f"{problem}\n{question}\n{options}"
# 日志配置
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)
# 示例数据
datas = [
{
"problem": "这是一个问题",
"questions": [
{"question": "选择题1", "options": "A B C D"},
{"question": "选择题2", "options": "E F G H"}
]
},
{
"problem": "这是另一个问题",
"questions": [
{"question": "选择题3", "options": "I J K L"},
{"question": "选择题4", "options": "M N O P"}
]
}
]
# 调用 process_datas 函数
results = process_datas(datas, "model_name")
# 打印结果
for result in results:
print(result)
这段示例代码模拟了 api_retry
和 get_prompt
函数的功能,并展示了如何调用 process_datas
函数。