Datawhale AI 夏令营-逻辑推理baselin01代码解析

Datawhale AI 夏令营-从零入门AI逻辑推理

1.赛题介绍

http://competition.sais.com.cn/competitionDetail/532231/competitionData

2.数据集介绍

初赛数据集为逻辑推理数据,其中训练集中包含500条训练数据,测试集中包含500条测试数据。每个问题包括若干子问题,每个子问题为单项选择题,选项不定(最多5个)。目标是为每个子问题选择一个正确答案。推理答案基于闭世界假设(closed-world assumption),即未观测事实或者无法推断的事实为假。
具体的,每条训练数据包含 content, questions字段,其中content是题干,questions为具体的子问题。questions是一个子问题列表,每个子问题包括optionsanswer字段,其中options是一个列表,包含具体的选项,按照ABCDE顺序排列,answer是标准答案。
数据集格式如下:

  • round1_train_data.jsonl : 每一行代表一条反应
{'id': 'round_train_data_001',
'problem': '有一个计算阶乘的递归程序。该程序根据给定的数值计算其阶乘。以下是其工作原理:\n\n当数字是0时,阶乘是1。\n对于任何大于0的数字,其阶乘是该数字乘以其前一个数字的阶乘。\n根据上述规则,回答以下选择题:',
 'questions': [{'question': '选择题 1:\n3的阶乘是多少?\n',
   				   'options': ('3', '6', '9', '12'),
   					'answer': 'B'},
  				  {'question': '选择题 2:\n8的阶乘是多少?\n',
   					'options': ('5040', '40320', '362880', '100000'),
   					'answer': 'B'},
  				  {'question': '选择题 3:\n4的阶乘是多少?\n',
   					'options': ('16', '20', '24', '28'),
   					'answer': 'C'},
  				  {'question': '选择题 4:\n3的阶乘是9吗?\n', 
  				   'options': ('是', '否'), 
  				   'answer': 'B'}]
}

测试集 round1_test_data.jsonl 不包含answer字段。

提交格式:

{'id': 'round1_test_data_000',
 'questions': [{'answer': 'A'}, {'answer': 'D'}, ...], # 顺序与子问题对应
}

3.baseline代码解析

3.1.环境配置

# scipy:科学计算库,包含许多有用的数学、科学和工程函数。
# openai:OpenAI的API库,用于访问OpenAI的GPT模型等服务。
# tiktoken:OpenAI提供的一个库,用于对文本进行分词和编码。
# retry:一个用于实现函数重试机制的库,当函数遇到可恢复的错误时,可以自动重试。
# dashscope:对主流的AI大模型进行封装,提供API接口。
# loguru:一个用于日志记录的库。
!pip install scipy openai tiktoken retry dashscope loguru
#导入相关库
from multiprocessing import Process, Manager
import json
import os
from pprint import pprint
import re
from tqdm import tqdm
import random

import uuid
import openai
import tiktoken
import json
import numpy as np
import requests
from retry import retry
from scipy import sparse
#from rank_bm25 import BM25Okapi
#import jieba
from http import HTTPStatus
import dashscope

from concurrent.futures import ThreadPoolExecutor, as_completed
from loguru import logger
import json
import time
from tqdm import tqdm

logger.remove()  # 移除默认的控制台输出
logger.add("logs/app_{time:YYYY-MM-DD}.log", level="INFO", rotation="00:00", retention="10 days", compression="zip")# 添加一个新的日志输出,以日期命名,每天轮换一次,保留 10 天,压缩格式为 zip

MODEL_NAME = 'qwen1.5-1.8b-chat'  # 截止到7.26日下午15:00,该模型在灵积平台限时免费

Qwen是阿里云推出的一系列基于Transformer的大型语言模型,在大量数据(包括网页文本、书籍、代码等)进行了预训练。

3.2. API调用函数

# 填入申请的的API key
dashscope.api_key="sk-"
'''
该函数用于调用一个名为 dashscope.Generation 的 API 来生成文本:具体如下
1.构造请求消息.函数首先将用户的查询封装成一个字典列表,其中每个字典表示一个消息。消息的角色为用户,内容为用户的查询。
2.调用 dashscope API。函数使用 dashscope.Generation.call 方法,向指定的模型(MODEL_NAME)发送消息,并指定结果格式为消息格式(result_format='message')。
3.检查响应状态
'''

def call_qwen_api(MODEL_NAME, query):
    # 这里采用dashscope的api调用模型推理,通过http传输的json封装返回结果
    messages = [
        {'role': 'user', 'content': query}]
    
    response = dashscope.Generation.call(
        MODEL_NAME,
        messages=messages,
        result_format='message',  # set the result is message format.
    )
    # 如果状态码为 HTTPStatus.OK,则提取并返回生成的文本内容。
    if response.status_code == HTTPStatus.OK:
        # print(response) # 调试时可以打印响应内容
        return response['output']['choices'][0]['message']['content']
    # 如果状态码不是 HTTPStatus.OK,请求失败,打印请求 ID、状态码、错误码和错误消息
    else:
        print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            response.request_id, response.status_code,
            response.code, response.message
        ))
        raise Exception() # 抛出异常以通知调用者请求失败
'''
这个函数 api_retry 用于在调用 call_qwen_api 函数时实现重试机制,以处理可能的网络或其他临时性问题。具体功能如下:
1.设置最大重试次数和重试间隔:定义最大重试次数 max_retries 为 5 次,重试间隔 retry_delay 为 60 秒。
2.初始化尝试次数:初始化尝试次数 attempts 为 0。
3.循环重试:使用 while 循环,当尝试次数小于最大重试次数时,执行以下步骤:
	再尝试调用 call_qwen_api 函数。
	如果调用成功,函数立即返回结果。
	如果调用失败,捕获异常,增加尝试次数,并记录警告日志,提示第几次尝试失败,并将在指定时间后重试。
	等待指定的重试间隔时间后,继续下一次尝试。
4.记录最终失败:如果已达到最大重试次数,记录错误日志,提示所有重试均已失败,并显示错误信息,然后抛出异常。
'''

def api_retry(MODEL_NAME, query):
    max_retries = 5  #  最大重试次数设为5次 
    retry_delay = 60  # 重试间隔时间为60秒
    attempts = 0 # 初始化尝试次数为0
    while attempts < max_retries: # 当尝试次数小于最大重试次数时
        try:
            return call_qwen_api(MODEL_NAME, query) # 调用 call_qwen_api 函数并返回结果
        except Exception as e: # 如果调用过程中出现异常
            attempts += 1   # 尝试次数加1
            if attempts < max_retries:  # 如果尝试次数小于最大重试次数
                logger.warning(f"Attempt {attempts} failed for text: {query}. Retrying in {retry_delay} seconds...")  # 记录警告日志,提示第几次尝试失败,并将在指定时间后重试
                time.sleep(retry_delay) # 等待指定的重试间隔时间
            else: # 如果已达到最大重试次数
                logger.error(f"All {max_retries} attempts failed for text: {query}. Error: {e}")
                 # 记录错误日志,提示所有重试均已失败,并显示错误信息
                raise  # 抛出异常

3.3. prompt工程

赛题提供了184条真实场景的群聊对话数据,包括训练数据129条,测试数据55条。字段提取的难易程度分为1、2、3三种,具体得分规

'''
这个函数 get_prompt 用于生成逻辑推理题目的提示(prompt),该提示可以用于逻辑推理模型的输入。具体功能如下:
1.格式化选项:将传入的选项列表(options)格式化为字符串,每个选项占一行,并添加字母序号(A, B, C, ...)。
	使用列表推导式和 enumerate 函数遍历选项列表。
	每个选项前添加相应的字母序号,并通过 '\n'.join 将所有选项组合成一个多行字符串。
2.构建提示模板:使用 f-string 构建一个提示字符串,包含以下部分:
	提示用户是一个逻辑推理专家,擅长解决逻辑推理问题。
	提示题目形式为单项选择题,并且所有问题都遵循闭世界假设。
	要求逐步分析问题,并在最后一行输出答案,格式为“答案是:A”。
	包含题目描述(problem)、问题内容(question)和选项(options)。
3.返回提示:返回生成的提示字符串,用于后续逻辑推理模型的输入。
'''

# 这里定义了一个用于生成推理题目prompt的函数
def get_prompt(problem, question, options):
    # 将选项格式化为每行一个选项,并加上字母序号(A, B, C, ...)
    options = '\n'.join(f"{'ABCDEFG'[i]}. {o}" for i, o in enumerate(options))

    # 构建提示模板
    prompt = f"""你是一个逻辑推理专家,擅长解决逻辑推理问题。以下是一个逻辑推理的题目,形式为单项选择题。所有的问题都是(close-world assumption)闭世界假设,即未观测事实都为假。请逐步分析问题并在最后一行输出答案,最后一行的格式为"答案是:A"。题目如下:

### 题目:
{problem}

### 问题:
{question}
{options}
"""
    # print(prompt)  # 调试时可以打印生成的提示内容
    return prompt  # 返回生成的提示内容

抽取函数

'''
这个函数 extract 用于从给定的输入文本中抽取答案。具体功能如下:
1.定义正则表达式模式:使用 re.compile 方法定义一个正则表达式模式 ans_pattern,用于匹配“答案是:”后面的一个字符。r"答案是:(.)" 中的 . 表示任意单个字符,括号用于捕获匹配的字符,re.S 使 . 可以匹配包括换行符在内的任意字符。
2.查找匹配项:使用 ans_pattern.findall 方法在输入文本 input_text 中查找所有符合模式的匹配项,结果存储在列表 problems 中。
3.检查匹配结果:
	如果没有找到任何匹配项(即 problems 为空),返回默认答案 'A'。
	如果找到了匹配项,返回列表中的第一个匹配答案。
'''

# 这里定义了一个用于从输入文本中抽取答案的函数
def extract(input_text):
    ans_pattern = re.compile(r"答案是:(.)", re.S)  # 定义正则表达式模式,用于匹配“答案是:”后面的一个字符(答案选项)

    problems = ans_pattern.findall(input_text)  # 使用 findall 方法在输入文本中查找所有符合模式的匹配项
    # print(problems)  # 调试时可以打印匹配的结果
    if(problems == ''):  # 如果没有找到匹配的答案
        return 'A'  # 默认返回答案 'A'
    return problems[0]  # 返回找到的第一个匹

3.4.数据处理

def process_datas(datas, MODEL_NAME):
    results = []  # 用于存储处理后的结果
    with ThreadPoolExecutor(max_workers=16) as executor:  # 创建一个最大并发数为16的线程池
        future_data = {}  # 用于存储每个 future 对应的数据和问题 ID
        lasttask = ''  # 初始化变量
        lastmark = 0  # 初始化变量
        lens = 0  # 记录总任务数

        # 提交任务
        for data in tqdm(datas, desc="Submitting tasks", total=len(datas)):  # 迭代数据集,并显示进度条
            problem = data['problem']  # 获取问题描述
            for id, question in enumerate(data['questions']):  # 迭代问题列表
                prompt = get_prompt(problem, 
                                    question['question'], 
                                    question['options'],
                                    )  # 生成提示

                future = executor.submit(api_retry, MODEL_NAME, prompt)  # 提交任务到线程池
                future_data[future] = (data, id)  # 存储 future 对应的数据和问题 ID
                time.sleep(0.6)  # 控制每0.6秒提交一个任务
                lens += 1  # 增加任务数

        # 处理任务
        for future in tqdm(as_completed(future_data), total=lens, desc="Processing tasks"):  # 迭代已完成的任务,并显示进度条
            data = future_data[future][0]  # 获取数据
            problem_id = future_data[future][1]  # 获取问题 ID
            try:
                res = future.result()  # 获取任务结果
                extract_response = extract(res)  # 从结果中抽取答案
                data['questions'][problem_id]['answer'] = extract_response  # 将答案添加到原始数据
                results.append(data)  # 将处理后的数据添加到结果列表
            except Exception as e:  # 捕获异常
                logger.error(f"Failed to process text: {data}. Error: {e}")  # 记录错误日志

    return results  # 返回处理后的结果

3.5.主函数

'''
读取输入文件中的数据,将其解析为 JSON 对象,并调用 `process_datas` 函数进行处理,最后返回处理后的数据集
'''

def main(ifn, ofn):
    if os.path.exists(ofn):
        pass  # 如果输出文件存在,则跳过处理
    data = []
    # 按行读取输入文件中的数据
    with open(ifn) as reader:
        for line in reader:
            sample = json.loads(line)  # 将每一行数据解析为 JSON 对象
            data.append(sample)  # 将解析后的对象添加到数据列表中
    datas = data  # 将数据列表赋值给 datas 变量
    # print(data)  # 调试时可以打印数据列表
    # 均匀地分成多个数据集
    return_list = process_datas(datas, MODEL_NAME)  # 调用 process_datas 函数处理数据
    print(len(return_list))  # 打印处理后的数据集长度
    print("All tasks finished!")  # 打印所有任务完成的消息
    return return_list  # 返回处理后的数据集

3.6.模型评估

#evaluate函数用于评估模型的预测结果
def evaluate(ofn):
    data = []
    with open(ofn) as reader:
        for line in reader:
            sample = json.loads(line)  # 将每一行数据解析为 JSON 对象
            data.append(sample)  # 将解析后的对象添加到数据列表中

    pse = 0  # 初始化未处理问题计数器
    cnt = 0  # 初始化正确答案计数器
    tot = 0  # 初始化总问题数计数器
    for task in data:
        for question in task['questions']:
            if MODEL_NAME in question:  # 如果问题中包含模型的预测答案
                tot += 1  # 增加总问题数
                cnt += question[MODEL_NAME] == question['answer']  # 如果模型预测答案与正确答案相同,增加正确答案计数器
            else:
                pse += 1  # 如果问题中不包含模型的预测答案,增加未处理问题计数器

    print(cnt, tot, cnt/tot, pse)  # 打印正确答案数、总问题数、正确率和未处理问题数

3.7.测试

if __name__ == '__main__':  # 如果这个脚本是作为主程序运行
    # 测试 extract 函数
    a = extract("""根据欧几里得算法,逐步解析计算两个数6和7的最大公约数(gcd)的步骤如下:

1. 判断6和7是否相等:不相等。
2. 判断6和7大小关系,7 > 6,所以用更大的数7减去较小的数6得到结果1。
3. 现在计算6和1的最大公约数。
4. 6 > 1,根据算法用更大的数6减去较小的数1得到结果5。
5. 再计算5和1的最大公约数。
6. 5 > 1,用5减去1得到结果4。
7. 再计算4和1的最大公约数。
8. 4 > 1,用4减去1得到结果3。
9. 再计算3和1的最大公约数。
10. 3 > 1,用3减去1得到结果2。
11. 再计算2和1的最大公约数。
12. 2 > 1,用2减去1得到结果1。
13. 最后计算1和1的最大公约数,两数相等,gcd即为这两个数,也就是1。

因此,6和7的最大公约数是1。

答案是:C.""")

    print(a)  # 打印从输入文本中提取的答案

    # 调用 main 函数处理数据
    return_list = main('round1_test_data.jsonl', 'upload.jsonl')

Submitting tasks:   0%|          | 2/500 [00:02<10:52,  1.31s/it]
{"status_code": 200, "request_id": "97b8fbae-528c-9ecd-9370-0f2c167843b0", "code": "", "message": "", "output": {"text": null, "finish_reason": null, "choices": [{"finish_reason": "stop", "message": {"role": "assistant", "content": "分析已知信息:\n- 我们知道鸡肉和苹果是食物。\n- Bill存活并吃了花生。\n- John吃所有食物,因此他吃了鸡肉、苹果和花生。\n- Sue吃所有Bill吃的食物,所以她吃了鸡肉、苹果和花生。\n- John喜欢所有食物,并且已经吃了这些食物。\n\n根据以上信息,我们可以得出结论:\n- Bill吃了花生。\n- Sue也吃了花生,因为她吃所有Bill吃的食物。\n- John吃了花生,因为他吃所有食物。\n- 所以,至少Bill、Sue和John都吃了花生。\n\n答案是:D. None of the above"}}]}, "usage": {"input_tokens": 210, "output_tokens": 128, "total_tokens": 338}}

3.8.纠错与结果文件生成

def has_complete_answer(questions):
    # 这里假设完整答案的判断逻辑是:每个question都有一个'answer'键
    for question in questions:
        if 'answer' not in question:
            return False
    return True

这个函数的作用是过滤一组问题数据,以确保每个问题在结果列表中只出现一次,并且如果同一个问题有多个条目,优先保留包含完整答案的条目:

def filter_problems(data):
    result = []
    problem_set = set()
    
    for item in data:
        # 获取当前条目的 'problem' 字段
        problem = item['problem']
        if problem in problem_set:
            # 如果问题已经存在于 problem_set 中,找到对应的字典
            for existing_item in result:
                if existing_item['problem'] == problem:
                    # 如果当前字典有完整答案,替换已存在的字典
                    if has_complete_answer(item['questions']):
                        existing_item['questions'] = item['questions']
                        existing_item['id'] = item['id']
                    break
        else:
            # 如果问题尚未在 problem_set 中出现
            # 如果当前字典有完整答案,添加到结果列表
            if has_complete_answer(item['questions']):
                result.append(item)
                # 将问题添加到 problem_set 中
                problem_set.add(problem)

    return result
return_list
return_list = filter_problems(return_list)
# 对过滤后return_list 进行排序,排序的关键是字典中 id 字段的后三位数字
sorted_data = sorted(return_list, key=lambda x: int(str(x['id'])[-3:])) 

print(sorted_data)
[{'problem': '有一群人和一些食物类型。下列是关于这些个体和食物的已知信息:\n\n1. 鸡肉是一种食物。\n2. 苹果是一种食物。\n3. 如果X吃了Y,且X活着,则Y是一种食物。\n4. Bill存活。\n5. Bill吃了花生。\n6. John吃所有食物。\n7. Sue吃所有Bill吃的食物。\n8. John喜欢所有食物。\n\n根据以上信息,回答以下选择题:', 
'questions': [{'question': '选择题 1:\n谁喜欢吃花生?', 'options': ['Bill', 'Sue', 'John', 'None of the above'], 'answer': 'D'}], 
'id': 'round1_test_data_000'
}, 
{'problem': '设有两个列表操作,第一个列表中包含若干元素,并且可以将第二个列表的元素追加到第一个列表的末尾形成一个新的列表。根据这个操作,请回答以下问题:',
'questions': [{'question': '选择题 1:\n当第一个列表是[a, b, c],第二个列表是[d, e]时,新的列表是什么?', 'options': ['[a, b, c, d]', '[a, d, e, b, c]', '[a, b, c, d, e]', '[d, e, a, b, c]'],
'answer': 'C'}, {'question': '选择题 2:\n如果新列表是[a, b, c, d, e],而第二个列表是[d, e],那么第一个列表是什么?','options': ['[a]', '[a, b]', '[a, b, c]', '[b, c]'],'answer': 'C'},
def find_missing_ids(dict_list):
    # 提取所有序号
    extracted_ids = {int(d['id'][-3:]) for d in dict_list}
    
    # 创建0-500的序号集合
    all_ids = set(range(500))
    
    # 找出缺失的序号
    missing_ids = all_ids - extracted_ids
    
    return sorted(missing_ids)

# 示例字典列表
dict_list = sorted_data

# 找出缺失的序号
missing_ids = find_missing_ids(dict_list)
print("缺失的序号:", missing_ids)  
# 缺失的序号: [33, 67, 93, 95, 183, 216, 229, 267, 309, 361]

len(missing_ids)
# 初始化一个空列表用于存储数据
data = []

# 打开 round1_test_data.jsonl 文件并逐行读取
with open('round1_test_data.jsonl') as reader:
    for id, line in enumerate(reader):
        # 检查当前行的 id 是否在 missing_ids 列表中
        if id in missing_ids:
            # 将当前行的 JSON 字符串解析为字典
            sample = json.loads(line)
            # 将 sample 中的每个问题的答案设置为 'A'
            for question in sample['questions']:
                question['answer'] = 'A'
            # 将处理后的 sample 添加到 sorted_data 列表中
            sorted_data.append(sample)

# 按照 id 的后三位对 sorted_data 列表进行排序
sorted_data = sorted(sorted_data, key=lambda x: int(str(x['id'])[-3:]))

# 打开 upload.jsonl 文件,并将处理后的数据逐行写入文件中
with open('upload.jsonl', 'w') as writer:
    for sample in sorted_data:
        # 将每个 sample 字典转换为 JSON 字符串并写入文件
        writer.write(json.dumps(sample, ensure_ascii=False))
        # 写入换行符以分隔不同的样本
        writer.write('\n')
        # 将处理后的 sample 添加到 sorted_data 列表中
        sorted_data.append(sample)
        # 按照 id 的后三位对 sorted_data 列表进行排序
        sorted_data = sorted(sorted_data, key=lambda x: int(str(x['id'])[-3:]))

# 打开 upload.jsonl 文件,并将处理后的数据逐行写入文件中
with open('upload.jsonl', 'w') as writer:
    for sample in sorted_data:
        # 将每个 sample 字典转换为 JSON 字符串并写入文件
        writer.write(json.dumps(sample, ensure_ascii=False))
        # 写入换行符以分隔不同的样本
        writer.write('\n')
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值