4-20 修复bug+QA对去重
修复bug
4月18日,因为电脑忘记设置息屏后不进入睡眠,导致本来预计四个小时应该就可以跑完四年的数据,第二天中午来看还剩一年的没有跑,立马更改息屏设置。
感觉将log的信息输出到文件里应该会比直接输出到控制台快点,于是更改了code将print的信息输出到了txt文件中。于是就放心的继续运行。
但是每次询问其实都会输出当前的json文件、glm4的回答等非常多的print,而这次要跑的是7年的数据,一共有545篇文章。
这就导致了因为电脑wifi意外断开,当我想要打开log文件查看当前已经询问到第几篇文章时,惊讶的发现txt文件大小接近500MB,只能以只读模式打开。
用python搜索errorcode的结果为空,我于是就用json文件的长度/4/5=287,定位到当前至少已经处理了287篇文章,于是接下来处理后面的文章了(但是现在总结报告的时候,我意识到可以重新生成一遍workstation,遍历workstation和log文件,匹配文章名,就可以判断最后一篇生成的文章是什么了)
接下来还遇到了token耗尽的问题。其实这个时候log信息已经精简输出了,只要查看errorcode就可以定位到最后一篇处理的文章名。但是我已经用json文件的长度/4/5得出了结论,并跑了一个小时,才意识到这个解决办法(
因此,01-07年的文章最后得到了三个json文件,里面有部分文章被询问了多次,显然需要根据相似度去重。
同时,20-23这四年的数据也需要将第一次prompt生成的符合条件的QA对加入到第二次生成的QA对中。
QA对去重
接下来就是去重,就是比对两个json文件,如果有相似度>70%的QA就去除。
思路如下:
已知有correct和updatePrompt两个json文件,将correct作为target文件,updatePrompt作为src文件,比较他们的相似度,然后将target文件中相似度低的加入到src文件中,最后输出到output文件中:
代码如下:
import json
from rouge import Rouge
import sys
def json_to_string(json_obj):
"""将JSON对象转换为字符串"""
return json.dumps(json_obj, ensure_ascii=False, sort_keys=True)
def compute_rouge_l_score(reference, summary):
"""
计算ROUGE-L分数
参数:
reference (str): 参考文本。
summary (str): 摘要文本。
返回:
dict: 包含ROUGE-L分数的字典,包括precision, recall, 和fmeasure。
"""
rouge = Rouge()
scores = rouge.get_scores(summary, reference, avg=True)
return scores['rouge-l']
def update_source_with_extension(source, extension,output_file, rouge_l_threshold=0.7):
"""
根据ROUGE-L分数更新source列表,将extension中与source中元素不相似(ROUGE-L分数不大于阈值)的元素添加到source中。
参数:
source (list): 原始的JSON对象列表。
extension (list): 要添加的JSON对象列表。
rouge_l_threshold (float): ROUGE-L分数的阈值,用于决定是否添加元素。
返回:
None (直接修改传入的source列表)
"""
for ext_element in extension:
ext_str = json_to_string(ext_element)
found_similar = False
for src_element in source:
src_str = json_to_string(src_element)
rouge_l = compute_rouge_l_score(src_str, ext_str)
# print(rouge_l)
if rouge_l['f'] > rouge_l_threshold:
found_similar = True
break
if not found_similar:
source.append(ext_element)
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(source, f, ensure_ascii=False, indent=4)
if __name__ == '__main__':
args = sys.argv
src_file = args[1]
target_file = args[2]
output_file = args[3]
with open(src_file, 'r', encoding='utf-8') as src_json:
src_lists = json.load(src_json)
with open(target_file, 'r', encoding='utf-8') as target_json:
target_lists = json.load(target_json)
update_source_with_extension(src_lists, target_lists,output_file)
20-23的数据最后存到了end20to23的json中。
而01-07因为网络和token问题,需要进行两次比较,先将前两部分的生成到end01to07first中,等到最后100篇文章跑完再次去重,最后生成到end01to07中。
最后生成的