​【Datawhale AI 夏令营第三期 Task 1 学习笔记】

​【Datawhale AI 夏令营第三期 Task 1 学习笔记】

Task 1 任务安排:

1. 跑通给的BaseLine,然后提交结果文件,获取第一个分数。
2. 听助教直播,了解赛事要求以及一些细节。
3. 听助教讲解BaseLine代码,尝试理解代码逻辑。

前两点比较简单,没啥好说的。这里就着重说一下第三点。

着重讲讲baseline中重点方法函数。

def process_datas(datas,MODEL_NAME):
    results = []
    with ThreadPoolExecutor(max_workers=16) as executor:
        future_data = {}
        lasttask = ''
        lastmark = 0
        lens = 0
        for data in tqdm(datas, desc="Submitting tasks", total=len(datas)):
            problem = data['problem']
            for id,question in enumerate(data['questions']):
                prompt = get_prompt(problem, 
                                    question['question'], 
                                    question['options'],
                                    )

                future = executor.submit(api_retry, MODEL_NAME, prompt)
                
                future_data[future] = (data,id)
                time.sleep(0.6)  # 控制每0.5秒提交一个任务
                lens += 1
        for future in tqdm(as_completed(future_data), total=lens, desc="Processing tasks"):
            # print('data',data)
            data = future_data[future][0]
            problem_id = future_data[future][1]
            try:
                res  = future.result()
                extract_response = extract(res)
                # print('res',extract_response)
                data['questions'][problem_id]['answer'] = extract_response
                results.append(data)
                # print('data',data)
                
            except Exception as e:
                logger.error(f"Failed to process text: {data}. Error: {e}")
    
    return results

这段代码定义了一个名为 process_datas 的函数,该函数用于处理一系列数据,并利用多线程执行异步任务。下面是代码的详细解释:

函数签名

def process_datas(datas, MODEL_NAME):
  • 参数

    • datas: 一个列表,其中每个元素都是一个字典,包含问题和相关问题的选项等信息。
    • MODEL_NAME: 一个字符串,表示要使用的模型名称。
  • 返回值

    • 一个列表,包含了处理后的数据。

函数内部逻辑

  1. 初始化

    • 创建一个空列表 results 用于存储处理后的数据。
    • 创建一个 ThreadPoolExecutor 实例,指定最大工作线程数为 16。
    • 初始化一个字典 future_data 用于存储提交的任务和对应的原始数据。
    • 初始化变量 lasttasklastmark,它们在这里没有被使用。
    • 初始化变量 lens 用于记录提交的任务总数。
  2. 提交任务

    • 使用 tqdm 进度条显示提交任务的过程。
    • 遍历输入列表 datas 中的每个数据项。
    • 对于每个数据项中的每个问题 question
      • 构造一个提示(prompt)字符串,该字符串包含了问题和选项。
      • 使用 executor.submit 提交一个任务,该任务调用 api_retry 函数,并将 MODEL_NAMEprompt 作为参数传递。
      • 将任务的 Future 对象和对应的数据项存储在 future_data 字典中。
      • 每提交一个任务后,暂停 0.6 秒以控制任务提交的速度。
      • 增加任务计数器 lens
  3. 处理完成的任务

    • 使用 as_completed 函数获取已完成的任务,并使用 tqdm 显示进度。
    • 遍历已完成的任务,并从中获取任务的原始数据项。
    • 尝试获取任务的结果:
      • 调用 future.result() 获取任务的返回值。
      • 调用 extract 函数处理返回值,并将结果存储回原始数据项的相应字段。
      • 将处理后的数据项追加到结果列表 results 中。
    • 如果处理过程中出现异常,记录错误信息。
  4. 返回结果

    • 最终返回包含处理后数据的列表 results

代码分析

  1. 多线程处理

    • 使用 ThreadPoolExecutor 来并发执行任务,这提高了处理效率。
    • 每个任务独立提交,通过 submit 方法异步执行。
  2. 错误处理

    • 通过 try-except 结构捕获并记录处理过程中可能出现的异常。
  3. 进度跟踪

    • 使用 tqdm 库显示任务提交和处理的进度条,方便监控进程。

注意事项

  • 依赖项

    • 这段代码依赖于 tqdmconcurrent.futures 库,确保这些库已安装。
    • api_retryget_prompt 函数需要在外部定义。
    • logger 应该是预先配置的日志记录器。
  • extract 函数

    • 该函数用于从 API 返回的结果中提取答案。需要确保 extract 函数已经被定义。
  • 任务调度

    • 代码中每提交一个任务后暂停 0.6 秒,这有助于避免对 API 的频繁请求。

示例使用

import concurrent.futures
from tqdm import tqdm
import time
import logging

# 假设的 api_retry 函数
def api_retry(model_name, prompt):
    # 模拟 API 调用
    time.sleep(1)  # 模拟延迟
    return "答案是:C"

# 假设的 get_prompt 函数
def get_prompt(problem, question, options):
    # 构造提示字符串
    return f"{problem}\n{question}\n{options}"

# 日志配置
logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger(__name__)

# 示例数据
datas = [
    {
        "problem": "这是一个问题",
        "questions": [
            {"question": "选择题1", "options": "A B C D"},
            {"question": "选择题2", "options": "E F G H"}
        ]
    },
    {
        "problem": "这是另一个问题",
        "questions": [
            {"question": "选择题3", "options": "I J K L"},
            {"question": "选择题4", "options": "M N O P"}
        ]
    }
]

# 调用 process_datas 函数
results = process_datas(datas, "model_name")

# 打印结果
for result in results:
    print(result)

这段示例代码模拟了 api_retryget_prompt 函数的功能,并展示了如何调用 process_datas 函数。

  • 14
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值