对菜鸟很友好@Datawhale AI 夏令营

从零入门AI+逻辑推理 是 Datawhale 2024 年 AI 夏令营第三期的学习活动(AI+逻辑推理”方向),基于“第二届世界科学智能大赛逻辑推理赛道:复杂推理能力评估” 开展学习活动——

  • 适合想 学习大语言模型、想用AI完成逻辑推理任务 的学习者参与

  • 适合纯小白学习和体验大语言模型


    stop1 baseline的运行

  • 按照从零入门 AI 逻辑推理速通手册的步骤来进行操作,进行大模型的申请和第一次baseline的运行。

  • 切记一定要注意Tokens的使用情况,本人在第一次操作时,跑了两遍代码,结果就付费了。


    代码的理解:

  • pip install scipy openai tiktoken retry dashscope logurufrom multiprocessing import Process, Manager
    import json
    import os
    from pprint import pprint
    import re
    from tqdm import tqdm
    import random
    
    import uuid
    import openai
    import tiktoken
    import json
    import numpy as np
    import requests
    from retry import retry
    from scipy import sparse
    #from rank_bm25 import BM25Okapi
    #import jieba
    from http import HTTPStatus
    import dashscope
    
    
    from concurrent.futures import ThreadPoolExecutor, as_completed
    from loguru import logger
    import json
    import time
    from tqdm import tqdm
    
    logger.remove()  # 移除默认的控制台输出
    logger.add("logs/app_{time:YYYY-MM-DD}.log", level="INFO", rotation="00:00", retention="10 days", compression="zip")
    
    MODEL_NAME = 'qwen2-7b-instruct'
    

    前期的准备工作,环境搭建,库文件等。配置和初始化一个Python程序的环境,包括日志记录、进程管理、数据处理、API调用。logger.remove 表示移除不合适的日志操作。api.retry  重试机制,多次调用api

  • def api_retry(MODEL_NAME, query):
        max_retries = 5
        retry_delay = 60  # in seconds
        attempts = 0
        while attempts < max_retries:
            try:
                return call_qwen_api(MODEL_NAME, query)
            except Exception as e:
                attempts += 1   
                if attempts < max_retries:
                    logger.warning(f"Attempt {attempts} failed for text: {query}. Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                else:
                    logger.error(f"All {max_retries} attempts failed for text: {query}. Error: {e}")
                    raise

    用于实现一个带有重试机制的API调用函数,以处理可能发生的临时错误,确保在达到最大重试次数前尝试重新执行API调用。

  • def api_retry(MODEL_NAME, query):
        max_retries = 5
        retry_delay = 60  # in seconds
        attempts = 0
        while attempts < max_retries:
            try:
                return call_qwen_api(MODEL_NAME, query)
            except Exception as e:
                attempts += 1   
                if attempts < max_retries:
                    logger.warning(f"Attempt {attempts} failed for text: {query}. Retrying in {retry_delay} seconds...")
                    time.sleep(retry_delay)
                else:
                    logger.error(f"All {max_retries} attempts failed for text: {query}. Error: {e}")
                    raise

    通过dashscope的API调用一个名为MODEL_NAME的模型进行推理,发送包含用户查询的JSON消息,处理返回的HTTP响应,如果响应状态码表示成功,则返回模型生成的消息内容;如果失败,则打印错误信息并抛出异常。


    # 这里定义了prompt推理模版
    
    def get_prompt(problem, question, options):
    
        options = '\n'.join(f"{'ABCDEFG'[i]}. {o}" for i, o in enumerate(options))
    
        prompt = f"""你是一个逻辑推理专家,擅长解决逻辑推理问题。以下是一个逻辑推理的题目,形式为单项选择题。所有的问题都是(close-world assumption)闭世界假设,即未观测事实都为假。请逐步分析问题并在最后一行输出答案,最后一行的格式为"答案是:A"。题目如下:
    
    ### 题目:
    {problem}
    
    ### 问题:
    {question}
    {options}
    """
        # print(prompt)
    return prompt
    
    
    # 这里使用extract抽取模获得抽取的结果
    
    def extract(input_text):
        ans_pattern = re.compile(r"答案是:(.)", re.S)
    
        problems = ans_pattern.findall(input_text)
        # print(problems)
        if(problems == ''):
            return 'A'
        return problems[0]
    
    

    利用prompt模板,其中包括三个参数problem, question, options,从处理该模板的模型响应中提取出预测的答案。


  • def has_complete_answer(questions):
        # 这里假设完整答案的判断逻辑是:每个question都有一个'answer'键
        for question in questions:
            if 'answer' not in question:
                return False
        return True
    
    def filter_problems(data):
        result = []
        problem_set = set()
    
        for item in data:
            # print('处理的item' ,item)
            problem = item['problem']
            if problem in problem_set:
                # 找到已存在的字典
                for existing_item in result:
                    if existing_item['problem'] == problem:
                        # 如果当前字典有完整答案,替换已存在的字典
                        if has_complete_answer(item['questions']):
                            existing_item['questions'] = item['questions']
                            existing_item['id'] = item['id']
                        break
            else:
                # 如果当前字典有完整答案,添加到结果列表
                if has_complete_answer(item['questions']):
                    result.append(item)
                    problem_set.add(problem)
    
        return result

    用于从包含问题和答案的数据中筛选出具有完整答案的问题,并去除重复的问题,只保留具有最新或最完整答案的问题记录。


  • def process_datas(datas,MODEL_NAME):
        results = []
        with ThreadPoolExecutor(max_workers=16) as executor:
            future_data = {}
            lasttask = ''
            lastmark = 0
            lens = 0
            for data in tqdm(datas, desc="Submitting tasks", total=len(datas)):
                problem = data['problem']
                for id,question in enumerate(data['questions']):
                    prompt = get_prompt(problem, 
                                        question['question'], 
                                        question['options'],
                                        )
    
                    future = executor.submit(api_retry, MODEL_NAME, prompt)
                    
                    future_data[future] = (data,id)
                    time.sleep(0.6)  # 控制每0.5秒提交一个任务
                    lens += 1
            for future in tqdm(as_completed(future_data), total=lens, desc="Processing tasks"):
                # print('data',data)
                data = future_data[future][0]
                problem_id = future_data[future][1]
                try:
                    res  = future.result()
                    extract_response = extract(res)
                    # print('res',extract_response)
                    data['questions'][problem_id]['answer'] = extract_response
                    results.append(data)
                    # print('data',data)
                    
                except Exception as e:
                    logger.error(f"Failed to process text: {data}. Error: {e}")
        
        return results
    
    

    并发地向API提交数据集处理的请求,并将答案填充回原始数据结构中,最后返回包含答案的完整数据集


    def find_missing_ids(dict_list):
        # 提取所有序号
        extracted_ids = {int(d['id'][-3:]) for d in dict_list}
        
        # 创建0-500的序号集合
        all_ids = set(range(500))
        
        # 找出缺失的序号
        missing_ids = all_ids - extracted_ids
        
        return sorted(missing_ids)
    
    # 示例字典列表
    dict_list = sorted_data
    
    # 找出缺失的序号
    missing_ids = find_missing_ids(dict_list)
    print("缺失的序号:", missing_ids)
    

    找出缺失的序号,并将缺失的序列打印。


    声明:

  • 本人纯小白,代码水平能力不够,所以在听课的同时记录了老师讲的代码重点理解,结合注释进行个人解释,如果代码有理解错误或不当的地方,欢迎指正
  • 5
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值