Task2 baseline01 精读 #datawhale夏令营

写在前面

这个是datawhale夏令营2024年第三期的第二次Task的笔记,由于Task2要求的是精读代码,而我在第一篇文章中已经精读的差不多了,这篇文章我就总结升华一下上篇文章的内容,并且补充说明一下上一篇文章说的不太完善的__main__函数后面的部分内容。

好了,我们开始今天的探索之旅

代码的整体概览

datawhale给的这个流程图我觉得比较不错:

所以根据以上的信息,可以推断,上一次我们讲到的__main__后面的部分应该不是没什么用,应该就是所谓的“纠错与结果文件生成部分”根据这个思路,我们复盘一下最后这几段代码:

“纠错与结果文件生成部分”代码的复盘

我们先看看datawhale给出的对这一块的介绍:

纠错与结果生成部分存在的目的是由于目前使用了api调用在线开源大模型,因为网络、模型能力等原因会导致有一些结果会出现缺失。(比如大模型回答时,没有明确给出ABCD的结果,而返回的空值。也有时因为网络retry模块机会使用结束后,依然没有提取到结果会跳过某个问题。)

意思就是说,我们由于是线上调用的,有时候由于网络等问题我们这个大模型得到的答案可能会有所缺失,所以我们应该把这些“答得不好的”给过滤掉。

那么我们来看一下这个部分的代码

去重与排序

将一个问题背景下的所有问题存入同一个字典,并按id序号排序。

def has_complete_answer(questions):
    # 这里假设完整答案的判断逻辑是:每个question都有一个'answer'键
    for question in questions:
        if 'answer' not in question:
            return False
    return True

def filter_problems(data):
    result = []
    problem_set = set()

    for item in data:
        # print('处理的item' ,item)
        problem = item['problem']
        if problem in problem_set:
            # 找到已存在的字典
            for existing_item in result:
                if existing_item['problem'] == problem:
                    # 如果当前字典有完整答案,替换已存在的字典
                    if has_complete_answer(item['questions']):
                        existing_item['questions'] = item['questions']
                        existing_item['id'] = item['id']
                    break
        else:
            # 如果当前字典有完整答案,添加到结果列表
            if has_complete_answer(item['questions']):
                result.append(item)
                problem_set.add(problem)

    return result

return_list = filter_problems(return_list)
# 排序工作 通过id字段后三位代表序号
sorted_data = sorted(return_list, key=lambda x: int(str(x['id'])[-3:]))
print(sorted_data)

以上是datawhale给出来的所有的对这段代码的解读。

首先澄清一个事,就是我们从后往前看这个调用顺序,会发现这几个函数调用调用的是main(ifn,ofn)生成的return_list,也就是结果数据,那么我应该更正上一篇文章所提到的这个has_complete_answer函数的传入值里没有'answer'。因为他传入的是结果值,结果值里没有'answer'就在filter函数里把这个传入的data(在后面调用时传入的是result_list)append到这个result里面,返回result,这个result是过滤后最终的result,具体的话我们看后面。

纠错
def find_missing_ids(dict_list):
    # 提取所有序号
    extracted_ids = {int(d['id'][-3:]) for d in dict_list}
    
    # 创建0-500的序号集合
    all_ids = set(range(500))
    
    # 找出缺失的序号
    missing_ids = all_ids - extracted_ids
    
    return sorted(missing_ids)

这个就是利用集合的差来找到缺失的id,我们结合后面的调用或许会有更好的理解

后面的调用
return_list
return_list = filter_problems(return_list)
sorted_data = sorted(return_list, key=lambda x: int(str(x['id'])[-3:]))
print(sorted_data)
# 示例字典列表
dict_list = sorted_data

# 找出缺失的序号
missing_ids = find_missing_ids(dict_list)
print("缺失的序号:", missing_ids)

len(missing_ids)

从这里我们就可以看到,这个按id排好顺序的sorted_data就是find_missing_ids的参数,结合这个,是不就很快理解了上面这个找缺失值的函数了呢

补错

针对空缺的列表我们进行补错,让每个answer字段默认填充为A,当然如果这种补错机制大家觉得不满意可以再送入多线程函数处理一边。

data  = []
with open('round1_test_data.jsonl') as reader:
    for id,line in enumerate(reader):
        if(id in missing_ids):
            sample = json.loads(line)
            for question in sample['questions']:
                question['answer'] = 'A'
            sorted_data.append(sample)
sorted_data = sorted(sorted_data, key=lambda x: int(str(x['id'])[-3:]))
        

这个结合官方给的说明我们得知,这个就是一个把所有在missing_id里的id所对应的问题的答案都设为A,然后插入到原来的sorted_data里。

存储文件
with open('upload.jsonl', 'w') as writer:
    for sample in sorted_data:
        writer.write(json.dumps(sample, ensure_ascii=False))
        writer.write('\n')

最后这个就是创建一个文件(因为'w'模式下如果目录中没有这个文件他会自动创建这个文件),并且把整个回答的答案上传到这个文件中。

好了,以上就是我对这个项目的全部理解,由于水平有限,还希望大家可以多多指正

祝大家学习愉快!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值