Datawhale AI夏令营 第三期Task1 笔记

逻辑推理赛道baseline代码分析与总结

前言

主要是对baseline的代码进行了代码分析和流程总结,以及个人的一点关于prompt的想法

目录

  1. 引入依赖包
  2. 设置模型和API密钥
  3. API调用和重试机制
  4. 生成Prompt和解析结果
  5. 处理数据
  6. 主函数
  7. 评估和过滤
  8. 辅助函数
  9. 深度学习知识点总结

1 引入依赖包

首先,代码引入了多种Python库和模块,包括并行处理、日志记录、HTTP请求、重试机制等。以下是关键的依赖包及其作用:

from multiprocessing import Process, Manager
import json
import os
from pprint import pprint
import re
from tqdm import tqdm
import random
import uuid
import openai
import tiktoken
import numpy as np
import requests
from retry import retry
from scipy import sparse
from concurrent.futures import ThreadPoolExecutor, as_completed
from loguru import logger
import time
from http import HTTPStatus
import dashscope

2 设置模型和API密钥

接下来,代码设置了一个模型名称和API密钥,并配置了日志记录,这里用的大模型是qwen1.5-1.8b-chat,通过调用API的方式进行的
这里模型不知道可不可以通过huggingface下载使用,还没实验

MODEL_NAME = 'qwen1.5-1.8b-chat'
dashscope.api_key = "your_api_key_here"

logger.remove()  # 移除默认的控制台输出
logger.add("logs/app_{time:YYYY-MM-DD}.log", level="INFO", rotation="00:00", retention="10 days", compression="zip")
  • MODEL_NAME: 指定了使用的预训练模型名称。
  • dashscope.api_key: 设置了API密钥,用于访问模型。
  • logger: 通过loguru库配置日志记录,包括日志文件的保存、轮换和压缩。

3 API调用和重试机制

定义了两个函数:api_retrycall_qwen_apiapi_retry函数实现了API调用的重试机制,call_qwen_api函数负责实际的API调用。

def api_retry(MODEL_NAME, query):
    max_retries = 5
    retry_delay = 60  # in seconds
    attempts = 0
    while attempts < max_retries:
        try:
            return call_qwen_api(MODEL_NAME, query)
        except Exception as e:
            attempts += 1   
            if attempts < max_retries:
                logger.warning(f"Attempt {attempts} failed for text: {query}. Retrying in {retry_delay} seconds...")
                time.sleep(retry_delay)
            else:
                logger.error(f"All {max_retries} attempts failed for text: {query}. Error: {e}")
                raise

def call_qwen_api(MODEL_NAME, query):
    messages = [{'role': 'user', 'content': query}]
    response = dashscope.Generation.call(
        MODEL_NAME,
        messages=messages,
        result_format='message',  # set the result is message format.
    )
    if response.status_code == HTTPStatus.OK:
        return response['output']['choices'][0]['message']['content']
    else:
        print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
            response.request_id, response.status_code,
            response.code, response.message
        ))
        raise Exception()
  • api_retry: 实现了最多5次的重试机制,每次重试间隔60秒。
  • call_qwen_api: 通过dashscope库调用预训练模型,并返回结果。

4 生成Prompt和解析结果

定义了两个函数:get_promptextractget_prompt函数生成适合模型输入的Prompt,extract函数解析API返回的结果。
这里prompt策略可以改进,可以尝试使用普通prompt 和 COT prompt做一下对比,感觉COT prompt效果可能会有提升,可以尝试

def get_prompt(problem, question, options):
    options = '\n'.join(f"{'ABCDEFG'[i]}. {o}" for i, o in enumerate(options))
    prompt = f"""你是一个逻辑推理专家,擅长解决逻辑推理问题。以下是一个逻辑推理的题目,形式为单项选择题。所有的问题都是(close-world assumption)闭世界假设,即未观测事实都为假。请逐步分析问题并在最后一行输出答案,最后一行的格式为"答案是:A"。题目如下:
### 题目:
{problem}
### 问题:
{question}
{options}
"""
    return prompt

def extract(input_text):
    ans_pattern = re.compile(r"答案是:(.)", re.S)
    problems = ans_pattern.findall(input_text)
    if problems == '':
        return 'A'
    return problems[0]
  • get_prompt: 生成一个适合模型输入的Prompt,包含问题描述、选项等信息。
  • extract: 使用正则表达式解析模型返回的结果,提取答案。

5 处理数据

process_datas函数负责并行处理数据,调用API并解析结果。

def process_datas(datas, MODEL_NAME):
    results = []
    with ThreadPoolExecutor(max_workers=16) as executor:
        future_data = {}
        lens = 0
        for data in tqdm(datas, desc="Submitting tasks", total=len(datas)):
            problem = data['problem']
            for id, question in enumerate(data['questions']):
                prompt = get_prompt(problem, question['question'], question['options'])
                future = executor.submit(api_retry, MODEL_NAME, prompt)
                future_data[future] = (data, id)
                time.sleep(0.6)  # 控制每0.6秒提交一个任务
                lens += 1
        for future in tqdm(as_completed(future_data), total=lens, desc="Processing tasks"):
            data = future_data[future][0]
            problem_id = future_data[future][1]
            try:
                res = future.result()
                extract_response = extract(res)
                data['questions'][problem_id]['answer'] = extract_response
                results.append(data)
            except Exception as e:
                logger.error(f"Failed to process text: {data}. Error: {e}")
    return results
  • ThreadPoolExecutor: 用于并行处理多个任务,提高处理效率。
  • tqdm: 用于显示任务提交和处理的进度。
  • time.sleep(0.6): 控制每0.6秒提交一个任务,防止API请求过于频繁。

6 主函数

main函数负责读取输入数据,调用数据处理函数并写入输出数据。

def main(ifn, ofn):
    if os.path.exists(ofn):
        pass
    data = []
    with open(ifn) as reader:
        for line in reader:
            sample = json.loads(line)
            data.append(sample)
    datas = data
    return_list = process_datas(datas, MODEL_NAME)
    print(len(return_list))
    print("All tasks finished!")
    return return_list
  • 读取输入文件ifn中的数据,并调用process_datas函数处理数据。
  • 将处理后的结果写入输出文件ofn

7 评估和过滤

定义了两个函数:evaluatefilter_problemsevaluate函数评估模型结果,filter_problems函数过滤和去重问题。

def evaluate(ofn):
    data = []
    with open(ofn) as reader:
        for line in reader:
            sample = json.loads(line)
            data.append(sample)
    pse = 0
    cnt = 0
    tot = 0
    for task in data:
        for question in task['questions']:
            if MODEL_NAME in question:
                tot += 1
                cnt += question[MODEL_NAME] == question['answer']
            else:
                pse += 1
    print(cnt, tot, cnt/tot, pse)

def filter_problems(data):
    result = []
    problem_set = set()
    for item in data:
        problem = item['problem']
        if problem in problem_set:
            for existing_item in result:
                if existing_item['problem'] == problem:
                    if has_complete_answer(item['questions']):
                        existing_item['questions'] = item['questions']
                        existing_item['id'] = item['id']
                    break
        else:
            if has_complete_answer(item['questions']):
                result.append(item)
                problem_set.add(problem)
    return result
  • evaluate:

评估模型的正确率和完成度。

  • filter_problems: 过滤重复问题并保留完整答案。

8 辅助函数

定义了一些辅助函数,如has_complete_answerfind_missing_ids来检查答案完整性和查找缺失ID。

def has_complete_answer(questions):
    for question in questions:
        if 'answer' not in question:
            return False
    return True

def find_missing_ids(dict_list):
    extracted_ids = {int(d['id'][-3:]) for d in dict_list}
    all_ids = set(range(500))
    missing_ids = all_ids - extracted_ids
    return sorted(missing_ids)
  • has_complete_answer: 检查每个问题是否包含答案。
  • find_missing_ids: 查找缺失的ID,确保数据完整性。

知识点总结

Prompt工程-提示原则

原则一:编写清晰、具体的指令
应该通过提供尽可能清晰和具体的指令来表达希望模型执行的操作,这将引导模型给出正确的输出,并减少无关或不正确响应的可能。编写清晰的指令不意味着简短的指令,因为在许多情况下,更长的提示实际上更清晰且提供了更多上下文,这实际上可能导致更详细更相关的输出。
原则二:给模型时间去思考
如果模型匆忙地得出了错误的结论,应该尝试重新构思查询,请求模型在提供最终答案之前进行一系列相关的推理。换句话说,如果给模型一个在短时间或用少量文字无法完成的任务,它可能会猜测错误。这种情况对人来说也是一样的。如果让某人在没有时间计算出答案的情况下完成复杂的数学问题,他们也可能会犯错误。因此,在这些情况下,可以指示模型花更多时间思考问题,这意味着它在任务上花费了更多的计算资源。

提分想法

目前使用的prompt是:

prompt = f"""你是一个逻辑推理专家,擅长解决逻辑推理问题。以下是一个逻辑推理的题目,形式为单项选择题。所有的问题都是(close-world assumption)闭世界假设,即未观测事实都为假。请逐步分析问题并在最后一行输出答案,最后一行的格式为"答案是:A"。题目如下:
### 题目:
{problem}
### 问题:
{question}
{options}
"""

考虑可以使用不同的提示策略,如Few-Shots 、 COT 、 SC、 TOT 、 Step-Back
在这里插入图片描述

修改提示策略感觉是最简单的上分方法之一,对于同样的模型,同样的任务,使用不同的 Prompt,输出的结果也有不小的差异,所以这也是上分思路之一。

一般来说,使用Prompt技巧的结果会比不使用任何技巧要好,对于简单的任务并不是叠加所有的技巧就会更好,到达一定结果后,再叠加技巧不会提升效果。

目前针对这个任务来讲,我觉得ToT技巧可能会是最优提示策略,还需要后续实验~~~

  • 9
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值