AI逻辑推理-baseline学习&个人拓展学习 #Datawhale AI 夏令营

从零入门AI+逻辑推理 是  Datawhale 2024 年 AI 夏令营第三期的学习活动 的方向之一,基于“第二届世界科学智能大赛逻辑推理赛道:复杂推理能力评估” 赛事的要求,开展了一系列组队学习活动。

作为一个在AI、大模型领域知之甚少的小白,在Datawhale提供的学习资源与交流环境中,收获了许多新知。

当下就本次学习活动中的Task1 Task2的baseline,结合自己的学习、理解和与他人交流所得,记录一些想法。

>>Baseline相关内容

1.

!pip install scipy openai tiktoken retry dashscope loguru

首先执行了pip install命令安装了一些Python的库:

(1)Scipy:

  • 科学计算库,常用于数学、科学和工程领域。
  • 用于优化、线性代数、积分、插值、信号处理等。

eg from GPT:

import scipy
from scipy import linalg
import numpy as np

a = np.array([[1,2], [3,4]])
print(linalg.det(a))

(2)OpenAI:

  • OpenAI 的 API 客户端,用于访问 GPT-3 等模型。
  • 用于生成文本、自动化任务、聊天机器人等。

eg from GPT:

import openai

openai.api_key = 'your-api-key'
response = openai.Completion.create(
  model="text-davinci-003",
  prompt="Translate the following English text to French: 'Hello, how are you?'",
  max_tokens=60
)

print(response.choices[0].text.strip())

(3)Tiktoken:

  • 高效的文本标记化库,常用于自然语言处理

eg from GPT:

import tiktoken

enc = tiktoken.get_encoding("gpt2")
encoded = enc.encode("Hello world!")
print(encoded)

(4)Retry:

  • 用于实现重试机制的库,适合网络请求等可能失败的操作。

(5)Dashscope:

  • 用于数据可视化和构建仪表盘的库。

(6)Loguru:

  • 方便的日志记录库,简化了日志记录配置。

特别记录一下,因为自己最近也在学习Pytorch相关内容,过程中在这行简单的代码中也发现了有时会有不同,比方说下面二者:

!pip install numpy
pip install numpy

 询问GPT后了解到,前者 !pip install 常在 Jupyter Notebook 或其他支持代码块执行的交互式环境中使用,用于在笔记本或交互式环境中直接从代码单元执行 shell 命令。而 pip install 常在命令行或终端中使用,用于在标准 Python 开发环境中安装库和包。

我们这里是在 .ipynb 文件也就是JupyterNotebook类型的文件,所以会运用前者.

关于包管理工具pip的说明(参考了其他同学):Python pip的安装与使用

2.

from multiprocessing import Process, Manager
import json
import os
from pprint import pprint
import re
from tqdm import tqdm
import random

import uuid
import openai
import tiktoken
import json
import numpy as np
import requests
from retry import retry
from scipy import sparse
#from rank_bm25 import BM25Okapi
#import jieba
from http import HTTPStatus
import dashscope


from concurrent.futures import ThreadPoolExecutor, as_completed
from loguru import logger
import json
import time
from tqdm import tqdm

logger.remove()  # 移除默认的控制台输出
logger.add("logs/app_{time:YYYY-MM-DD}.log", level="INFO", rotation="00:00", retention="10 days", compression="zip")

MODEL_NAME = 'qwen1.5-1.8b-chat'  # 截止到7.26日下午15:00,该模型在灵积平台限时免费
dashscope.api_key="sk-"

这一部分就是import了各种各样的库,通过DashScope调用了阿里的'qwen1.5-1.8b-chat'模型.

(虽然注释是说截至到7.26免费,不过在这个时间之后使用貌似也没大问题..至少我还没有收到欠费的通知)

并且在最后调用阿里云的模型服务灵积 Dash Scope的api,对模型进行训练微调或者定制化.

(我其实一直不明确Dash Scope到底是一个什么定位的产品,所以这里也去做了一些初步的了解.)

-->Dash Scope的产品简介 基本概念 快速入门

其中,

logger.remove()  # 移除默认的控制台输出
logger.add("logs/app_{time:YYYY-MM-DD}.log", level="INFO", rotation="00:00", retention="10 days", compression="zip")

这段代码配置了 loguru 日志库,将日志记录到文件中,文件名格式为 app_YYYY-MM-DD.log,每天一个日志文件,并保留 10 天,过期的日志文件将被压缩。

Loguru 是一个第三方 Python 日志库,提供了一种简洁的方式记录应用程序的日志。

3.

def api_retry(MODEL_NAME, query):
    max_retries = 5
    retry_delay = 60  # in seconds
    attempts = 0
    while attempts < max_retries:
        try:
            return call_qwen_api(MODEL_NAME, query)
        except Exception as e:
            attempts += 1   
            if attempts < max_retries:
                logger.warning(f"Attempt {attempts} failed for text: {query}. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                logger.error(f"All {max_retries} attempts failed for text: {query}. Error: {e}")
                raise
def call_qwen_api(MODEL_NAME, query):
    # 这里采用dashscope的api调用模型推理,通过http传输的json封装返回结果
    messages = [
        {'role': 'user', 'content': query}]
    response = dashscope.Generation.call(
        MODEL_NAME,
        messages=messages,
        result_format='message',  # set the result is message format.
    )
    if response.status_code == HTTPStatus.OK:
        # print(response)
        return response['output']['choices'][0]['message']['content']
    else:
        print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            response.request_id, response.status_code,
            response.code, response.message
        ))
        raise Exception()

这里是对两个函数 api_retry 和 call_qwen_api 的定义.

(1)api_retry(MODEL_NAME, query):

该函数用于调用名为call_qwen_api的API来查询数据。它会尝试最多5次调用API,每次尝试之间有60秒的延迟。如果所有尝试都失败,则会抛出异常。函数中使用了一个循环来实现重试机制,通过增加尝试次数并检查是否达到最大重试次数来决定是否继续重试或抛出异常。在重试期间,会记录警告日志,在所有尝试失败后,会记录错误日志。

api_retry 函数通过提供自动重试机制和详细的日志记录,增强了API调用的鲁棒性和可维护性,特别是在面对不稳定或高负载的网络环境时。

(2)call_qwen_api(MODEL_NAME, query):

这段代码定义了一个名为call_qwen_api的函数,其目的是调用一个基于Dashscope平台的AI模型,通常这种模型可以是语言模型。这里调用的就是上文提及的阿里的通义千问系列的'qwen1.5-1.8b-chat'模型。

4.

def get_prompt(problem, question, options):

    options = '\n'.join(f"{'ABCDEFG'[i]}. {o}" for i, o in enumerate(options))

    prompt = f"""你是一个逻辑推理专家,擅长解决逻辑推理问题。以下是一个逻辑推理的题目,形式为单项选择题。所有的问题都是(close-world assumption)闭世界假设,即未观测事实都为假。请逐步分析问题并在最后一行输出答案,最后一行的格式为"答案是:A"。题目如下:

### 题目:
{problem}

### 问题:
{question}
{options}
"""
    # print(prompt)
    return prompt

这里则是定义了prompt的推理模块。

problem: 字符串类型,表示逻辑推理题目的背景信息或情境描述。

question: 字符串类型,表示具体的问题。

options: 列表类型,包含多个选项,每个选项也是一个字符串。

  • options = '\n'.join(f"{'ABCDEFG'[i]}. {o}" for i, o in enumerate(options))

使用列表推导式和字符串连接操作,将options列表中的每个元素转换为带有大写字母标签(A至G)的格式化字符串,并用换行符\n连接起来,形成一个多行字符串,每一行是一个选项。

  • prompt = f"""..."""

使用f-string(格式化字符串文字)构建一个长字符串prompt,其中包含了对AI模型的指导语句、问题的背景信息、具体问题以及选项。提示信息中明确指出AI模型应扮演逻辑推理专家的角色,并遵循闭世界假设原则(即未提及的事实视为不成立),最后要求模型以特定格式(“答案是:A”)输出答案。

5.

def extract(input_text):
    ans_pattern = re.compile(r"答案是:(.)", re.S)

    problems = ans_pattern.findall(input_text)
    # print(problems)
    if(problems == ''):
        return 'A'
    return problems[0]

这里使用extract抽取模获得抽取的结果。

6.

def process_datas(datas,MODEL_NAME):
    results = []
    with ThreadPoolExecutor(max_workers=16) as executor:
        future_data = {}
        lasttask = ''
        lastmark = 0
        lens = 0
        for data in tqdm(datas, desc="Submitting tasks", total=len(datas)):
            problem = data['problem']
            for id,question in enumerate(data['questions']):
                prompt = get_prompt(problem, 
                                    question['question'], 
                                    question['options'],
                                    )

                future = executor.submit(api_retry, MODEL_NAME, prompt)
                
                future_data[future] = (data,id)
                time.sleep(0.6)  # 控制每0.5秒提交一个任务
                lens += 1
        for future in tqdm(as_completed(future_data), total=lens, desc="Processing tasks"):
            # print('data',data)
            data = future_data[future][0]
            problem_id = future_data[future][1]
            try:
                res  = future.result()
                extract_response = extract(res)
                # print('res',extract_response)
                data['questions'][problem_id]['answer'] = extract_response
                results.append(data)
                # print('data',data)
                
            except Exception as e:
                logger.error(f"Failed to process text: {data}. Error: {e}")
    
    return results

这段代码定义了一个名为process_datas的函数,其核心功能是异步处理一组数据集,利用多线程加速对每个数据条目中问题的处理,通过调用预先定义的api_retry函数来获取AI模型的回答,并进一步解析这些回答以提取答案。

result作为一个空白列表,用于储存处理后的数据结果

  • with ThreadPoolExecutor(max_workers=16) as executor:

使用ThreadPoolExecutor上下文管理器创建一个线程池,指定最大工作线程数为16,用于并发执行任务。

  • future_data = {}
    for data in tqdm(datas, desc="Submitting tasks", total=len(datas)):
        ...
        future = executor.submit(api_retry, MODEL_NAME, prompt)
        future_data[future] = (data,id)
        time.sleep(0.6)  # 控制每0.5秒提交一个任务

遍历datas列表中的每个数据条目,为每个问题构建prompt并通过executor.submit提交任务到线程池中。future_data字典用于跟踪每个Future对象及其对应的数据和问题ID。time.sleep(0.6)用于控制任务提交的速率,防止短时间内提交过多任务导致API限流。

  • for future in tqdm(as_completed(future_data), total=lens, desc="Processing tasks"):
        ...
        try:
            res  = future.result()
            extract_response = extract(res)
            data['questions'][problem_id]['answer'] = extract_response
            results.append(data)
        except Exception as e:
            logger.error(f"Failed to process text: {data}. Error: {e}")

使用as_completed迭代器监控所有已提交任务的完成状态,一旦某个任务完成,就立即处理其结果。通过future.result()获取任务的返回值,然后使用extract函数解析结果以提取答案。解析后的答案被更新到原始数据中,并将处理后的数据追加到results列表中。如果处理过程中出现异常,将记录错误日志。最后返回处理后的数据结果。

总体来说,process_datas函数通过异步处理和多线程技术显著提高了处理大量数据的效率,同时利用api_retryextract函数分别实现了AI模型调用和结果解析的核心功能,确保了数据处理的完整性和准确性。

7.main函数的定义

def main(ifn, ofn):
    if os.path.exists(ofn):
        pass
    data = []
    # 按行读取数据
    with open(ifn) as reader:
        for line in reader:
            sample = json.loads(line)
            data.append(sample)
    datas = data
    # print(data)
    # 均匀地分成多个数据集
    return_list = process_datas(datas,MODEL_NAME)
    print(len(return_list))
    print("All tasks finished!")
    return return_list

main函数负责从指定的输入文件中读取数据,将每行的json字符串转换为字典,并将这些字典组成一个列表。随后,它调用process_datas函数处理这些数据,传入所需的AI模型名称。处理完成后,函数输出处理结果的数量,并返回处理后的数据列表。

8.

def evaluate(ofn):
    data = []
    with open(ofn) as reader:
        for line in reader:
            sample = json.loads(line)
            data.append(sample)

    pse = 0
    cnt = 0
    tot = 0
    for task in data:
        for question in task['questions']:
            
            if MODEL_NAME in question:
                tot += 1
                cnt += question[MODEL_NAME] == question['answer']
            else:
                pse += 1

    print(cnt, tot, cnt/tot, pse)

evaluate函数从指定的输出文件中读取数据,将每行的json字符串转换为字典,并统计模型预测答案与真实答案的匹配情况。它计算了总问题数量、正确预测的数量以及准确率,并输出这些统计数据。

9.

if __name__ == '__main__':

    a = extract("""根据欧几里得算法,逐步解析计算两个数6和7的最大公约数(gcd)的步骤如下:

1. 判断6和7是否相等:不相等。
2. 判断6和7大小关系,7 > 6,所以用更大的数7减去较小的数6得到结果1。
3. 现在计算6和1的最大公约数。
4. 6 > 1,根据算法用更大的数6减去较小的数1得到结果5。
5. 再计算5和1的最大公约数。
6. 5 > 1,用5减去1得到结果4。
7. 再计算4和1的最大公约数。
8. 4 > 1,用4减去1得到结果3。
9. 再计算3和1的最大公约数。
10. 3 > 1,用3减去1得到结果2。
11. 再计算2和1的最大公约数。
12. 2 > 1,用2减去1得到结果1。
13. 最后计算1和1的最大公约数,两数相等,gcd即为这两个数,也就是1。

因此,6和7的最大公约数是1。

答案是:C.""")

    print(a)
    return_list = main('round1_test_data.jsonl', 'upload.jsonl')

这部分代码首先测试了extract函数的功能,验证其能否正确从文本中抽取答案。接着,通过调用main函数处理了一个实际的数据集,并返回处理后的结果。

10.

def has_complete_answer(questions):
    # 这里假设完整答案的判断逻辑是:每个question都有一个'answer'键
    for question in questions:
        if 'answer' not in question:
            return False
    return True

def filter_problems(data):
    result = []
    problem_set = set()

    for item in data:
        # print('处理的item' ,item)
        problem = item['problem']
        if problem in problem_set:
            # 找到已存在的字典
            for existing_item in result:
                if existing_item['problem'] == problem:
                    # 如果当前字典有完整答案,替换已存在的字典
                    if has_complete_answer(item['questions']):
                        existing_item['questions'] = item['questions']
                        existing_item['id'] = item['id']
                    break
        else:
            # 如果当前字典有完整答案,添加到结果列表
            if has_complete_answer(item['questions']):
                result.append(item)
                problem_set.add(problem)

    return result

这里是对另外两个函数 has_complete_answer 和 filter_problems的定义。

  • def has_complete_answer(questions):
        for question in questions:
            if 'answer' not in question:
                return False
        return True

遍历问题列表questions中的每个问题。如果发现任何一个问题字典中没有'answer'键,则返回False。如果所有问题都有'answer'键,则返回True

这个函数主要用于检查一个问题列表是否包含完整答案。

  • def filter_problems(data):
        result = []
        problem_set = set()
    
        for item in data:
            problem = item['problem']
            if problem in problem_set:
                for existing_item in result:
                    if existing_item['problem'] == problem:
                        if has_complete_answer(item['questions']):
                            existing_item['questions'] = item['questions']
                            existing_item['id'] = item['id']
                        break
            else:
                if has_complete_answer(item['questions']):
                    result.append(item)
                    problem_set.add(problem)
    
        return result

初始化一个空列表result用于存放最终结果,以及一个空集合problem_set用于存储已处理的问题。

接下来是一连串地遍历和条件判断:

遍历输入数据data中的每个条目item。如果item中的问题problem已经在problem_set中,则查找result列表中已有的相同问题条目。

item包含完整答案,则更新已有的问题条目中的问题列表和ID。

item中的问题不在problem_set中,则检查item是否包含完整答案。

如果包含,则将item添加到result列表中,并将问题添加到problem_set中。最终返回处理后的结果列表result

这个函数主要用于从数据集中筛选出包含完整答案的问题,并确保每个问题只出现一次,同时更新具有完整答案的问题。

11.

return_list
return_list = filter_problems(return_list)
sorted_data = sorted(return_list, key=lambda x: int(str(x['id'])[-3:]))
print(sorted_data)

sorted_data

这段代码首先调用了 filter_problems 处理 return_list ,再使用 sorted 进行排序,并将排序后的数据列表print出来。

12.

def find_missing_ids(dict_list):
    # 提取所有序号
    extracted_ids = {int(d['id'][-3:]) for d in dict_list}
    
    # 创建0-500的序号集合
    all_ids = set(range(500))
    
    # 找出缺失的序号
    missing_ids = all_ids - extracted_ids
    
    return sorted(missing_ids)

# 示例字典列表
dict_list = sorted_data

# 找出缺失的序号
missing_ids = find_missing_ids(dict_list)
print("缺失的序号:", missing_ids)

临近尾声,这里定义了一个 find_missing_ids 函数,用于找出给定字符列表中缺失的符号。

  • extracted_ids = {int(d['id'][-3:]) for d in dict_list}

    使用集合推导式提取dict_list中每个字典d'id'键对应的值的最后三位数字,并转换为整数,形成一个集合extracted_ids

  • all_ids = set(range(500))
    missing_ids = all_ids - extracted_ids

创建一个包含0到499的整数集合all_ids,并通过集合差运算找出all_ids中不存在于extracted_ids中的元素,即缺失的序号集合。

  • return sorted(missing_ids)
    dict_list = sorted_data
    missing_ids = find_missing_ids(dict_list)
    print("缺失的序号:", missing_ids)

调用find_missing_ids函数处理sorted_data,找出缺失的序号。打印缺失的序号集合。

13.

len(missing_ids)

data  = []
with open('round1_test_data.jsonl') as reader:
    for id,line in enumerate(reader):
        if(id in missing_ids):
            sample = json.loads(line)
            for question in sample['questions']:
                question['answer'] = 'A'
            sorted_data.append(sample)
sorted_data = sorted(sorted_data, key=lambda x: int(str(x['id'])[-3:]))

len(missing_ids)返回缺失的序号总数。

而下方代码打开文件'round1_test_data.jsonl'并逐行读取,对于文件中的每一行,使用enumerate获取行索引id和行内容line,如果行索引idmissing_ids列表中,则将该行的内容转换为Python字典sample。对于sample中的每个问题,将'answer'键设置为'A',并将处理后的sample添加到sorted_data列表中。

最后使用sorted函数对sorted_data进行排序。排序的关键字是每个问题字典中'id'键对应的值的最后三位数字,转换为整数后作为排序依据。

14.

with open('upload.jsonl', 'w') as writer:
    for sample in sorted_data:
        writer.write(json.dumps(sample, ensure_ascii=False))
        writer.write('\n')

使用with语句打开文件'upload.jsonl'以写入模式('w'),遍历sorted_data列表中的每个字典sample。将每个sample字典转换为json格式的字符串,并写入文件,确保非ASCII字符不会被转义(ensure_ascii=False),并在每个json字符串后面写入换行符\n。处理后的数据列表就这样被转换为json格式的字符串,并写入到文件upload.jsonl中。

总结:

借助大模型工具,把baseline读下来之后,对于其中的一些实现细节的认识还有待加深。这次内容中很多产出还是仰仗了大模型的,这份笔记权当之后深入学习或者复习的材料,并且当作鞭策自己进一步学习的停顿。

希望在之后的学习中自己对于代码中不熟悉的知识点或多或少会有一些深入了解。

如果内容中有什么错误之处或者理解不到位的地方,欢迎理性指正与交流。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值