[Code Reproduction] Knowledge Representation Learning with MHGRN: Preprocessing (Part 3)


Scalable Multi-Hop Relational Reasoning for Knowledge-Aware Question Answering


The model incorporates external knowledge into the reasoning process and achieves SOTA results on the CommonsenseQA dataset.
The external knowledge source is ConceptNet.
The code is organized into 5 steps:

  1. Download the datasets
  2. Preprocess the datasets
  3. Hyperparameter search (optional)
  4. Training
  5. Evaluation

This part focuses on step 2: dataset preprocessing.
The previous part walked through the processing pipeline for the CSQA training data;
this part continues with the remaining pieces of that stage.

Grounding: extract the concepts mentioned in the training data against the concept vocabulary. (An illustrative input line and example call follow the function below.)

def ground(statement_path, cpnet_vocab_path, pattern_path, output_path, num_processes=1, debug=False):
    # global PATTERN_PATH, CPNET_VOCAB
    # if PATTERN_PATH is None:
    #     PATTERN_PATH = pattern_path
    #     CPNET_VOCAB = load_cpnet_vocab(cpnet_vocab_path)
    # load the CSQA statement file (one JSON object per line)
    sents = []
    answers = []
    with open(statement_path, 'r', encoding='utf-8') as fin:
        lines = [line for line in fin]

    if debug:
        lines = lines[192:195]
        print(len(lines))
    for line in lines:
        if line == "":
            continue
        j = json.loads(line)
        for statement in j["statements"]: # statements in textual-entailment form (one per answer choice)
            sents.append(statement["statement"])
        for answer in j["question"]["choices"]:
            ans = answer['text']
            # ans = " ".join(answer['text'].split("_"))
            try:
                assert all([i != "_" for i in ans])
            except Exception:
                print(ans)
            answers.append(ans)
    # match the mentioned concepts against the ConceptNet vocabulary
    res = match_mentioned_concepts(sents, answers, num_processes, pattern_path, cpnet_vocab_path)
    res = prune(res, cpnet_vocab_path) # prune the matched concepts

    # check_path(output_path)
    with open(output_path, 'w', encoding='utf-8') as fout:
        for dic in res:
            fout.write(json.dumps(dic) + '\n')

    print(f'grounded concepts saved to {output_path}')
    print()
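
For reference, the sketch below shows a simplified input line of the kind ground() reads, plus a hedged example call. The field names mirror what the function accesses above; real statement files contain additional fields, and the paths in the commented call are hypothetical placeholders rather than the repo's guaranteed layout.

# Illustrative, simplified *.statement.jsonl line; only the fields ground() touches are shown.
example_line = {
    "question": {
        "choices": [
            {"label": "A", "text": "bank"},
            {"label": "B", "text": "library"},
        ]
    },
    "statements": [  # one textual-entailment statement per answer choice
        {"statement": "Revolving doors are a security measure at a bank."},
        {"statement": "Revolving doors are a security measure at a library."},
    ],
}

# Hypothetical example call (paths are placeholders):
# ground("train.statement.jsonl", "concept.txt", "matcher_patterns.json",
#        "train.grounded.jsonl", num_processes=4)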

match_mentioned_concepts extracts the concepts mentioned in the questions and answers, using multiple processes.
For the usage of Pool.imap, see the minimal sketch after the function below.

def match_mentioned_concepts(sents, answers, num_processes, pattern_path, cpnet_vocab_path):
    input_args = []
    for s, a in zip(sents, answers):
        input_args.append((s, a, pattern_path, cpnet_vocab_path))

    with Pool(num_processes) as p:
        res = list(tqdm(p.imap(ground_qa_pair, input_args), total=len(sents))) # ground each (statement, answer) pair against the concept vocabulary in parallel
    return res
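
The original post links to a separate write-up on imap; since that link is not reproduced here, the following is a minimal, standalone sketch (not the repo's code) of how Pool.imap combines with tqdm: imap yields results lazily and in input order, which lets tqdm show progress.

from multiprocessing import Pool

from tqdm import tqdm

def square(x):
    # stand-in for ground_qa_pair: any picklable one-argument function works
    return x * x

if __name__ == "__main__":
    inputs = list(range(1000))
    with Pool(4) as p:
        # imap yields results lazily in input order, so tqdm can track progress
        results = list(tqdm(p.imap(square, inputs), total=len(inputs)))
    print(results[:5])  # [0, 1, 4, 9, 16]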

The core grounding function (an example of the returned dictionary follows it):

def ground_qa_pair(qa_pair):
    s, a, pattern_path, cpnet_vocab_path = qa_pair

    global nlp, matcher, CPNET_VOCAB
    if nlp is None or matcher is None or CPNET_VOCAB is None:
        nlp = spacy.load('en_core_web_sm', disable=['ner', 'parser', 'textcat']) # load the English spaCy model
        nlp.add_pipe(nlp.create_pipe('sentencizer')) # add a sentencizer pipeline component (spaCy 2.x API)
        matcher = load_matcher(nlp, pattern_path) # build the Matcher from the precomputed pattern file
        CPNET_VOCAB = load_cpnet_vocab(cpnet_vocab_path)

    all_concepts = ground_mentioned_concepts(nlp, matcher, s, a) # concepts mentioned in the full statement (question combined with the answer choice)
    answer_concepts = ground_mentioned_concepts(nlp, matcher, a) # concepts mentioned in the answer
    question_concepts = all_concepts - answer_concepts # question concepts = statement concepts minus answer concepts
    if len(question_concepts) == 0:
        question_concepts = hard_ground(nlp, s, CPNET_VOCAB)  # not very possible

    if len(answer_concepts) == 0:
        answer_concepts = hard_ground(nlp, a, CPNET_VOCAB)  # some case 

    # question_concepts = question_concepts -  answer_concepts
    question_concepts = sorted(list(question_concepts))
    answer_concepts = sorted(list(answer_concepts))
    return {"sent": s, "ans": a, "qc": question_concepts, "ac": answer_concepts}
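
For one (statement, answer) pair, the returned dictionary looks roughly like the example below; the concept values are made up for illustration, and only the key structure is taken from the code above.

# Illustrative output of ground_qa_pair (concept values are invented):
example_output = {
    "sent": "Revolving doors are a security measure at a bank.",
    "ans": "bank",
    "qc": ["revolving_door", "security", "security_measure"],  # question concepts
    "ac": ["bank"],                                            # answer concepts
}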

Extracting the mentioned concepts

For the usage of spaCy's Matcher, see the Matcher documentation; a minimal self-contained example also follows the function below.

def ground_mentioned_concepts(nlp, matcher, s, ans=None): # s is the statement text, ans is the candidate answer (optional)
    s = s.lower()
    doc = nlp(s)
    matches = matcher(doc)

    mentioned_concepts = set()
    span_to_concepts = {}

    if ans is not None:
        ans_matcher = Matcher(nlp.vocab)
        ans_words = nlp(ans)
        # print(ans_words)
        ans_matcher.add(ans, None, [{'TEXT': token.text.lower()} for token in ans_words]) # tokenize the answer and add it to the matcher as a token-level pattern

        ans_match = ans_matcher(doc) # find occurrences of the answer inside the statement
        ans_mentions = set()
        for _, ans_start, ans_end in ans_match:
            ans_mentions.add((ans_start, ans_end))

    for match_id, start, end in matches:
        if ans is not None:
            if (start, end) in ans_mentions:
                continue

        span = doc[start:end].text  # the matched span

        # a word that appears in answer is not considered as a mention in the question
        # if len(set(span.split(" ")).intersection(set(ans.split(" ")))) > 0:
        #     continue
        original_concept = nlp.vocab.strings[match_id] # recover the concept name (rule key) from the match id
        original_concept_set = set()
        original_concept_set.add(original_concept) # start the candidate set with the original concept

        # print("span", span)
        # print("concept", original_concept)
        # print("Matched '" + span + "' to the rule '" + string_id)

        # why do you lemmatize a mention whose len == 1?

        if len(original_concept.split("_")) == 1:
            # tag = doc[start].tag_
            # if tag in ['VBN', 'VBG']:

            original_concept_set.update(lemmatize(nlp, nlp.vocab.strings[match_id])) # for single-word concepts, also add the lemmatized form

        if span not in span_to_concepts:
            span_to_concepts[span] = set()
        # map each matched span to its candidate concept surface forms
        span_to_concepts[span].update(original_concept_set) # e.g. {'accelerator': {'accelerators', 'accelerator'}, 'controller': {'controller'}}

    for span, concepts in span_to_concepts.items():
        concepts_sorted = list(concepts) # convert the set to a list
        # print("span:")
        # print(span)
        # print("concept_sorted:")
        # print(concepts_sorted)
        concepts_sorted.sort(key=len) # sort candidates by length (shortest first)

        # mentioned_concepts.update(concepts_sorted[0:2])

        shortest = concepts_sorted[0:3] # keep the three shortest candidates

        for c in shortest:
            if c in blacklist: # skip blacklisted concepts
                continue

            # a set with one string like: set("like_apples")
            lcs = lemmatize(nlp, c) # lemmatize (spaces replaced by '_'), e.g. the lemma of 'accelerators' is 'accelerator'
            intersect = lcs.intersection(shortest) # prefer a lemmatized form that is already among the shortest candidates
            if len(intersect) > 0:
                mentioned_concepts.add(list(intersect)[0])
            else:
                mentioned_concepts.add(c)

        # if a mention exactly matches with a concept

        exact_match = set([concept for concept in concepts_sorted if concept.replace("_", " ").lower() == span.lower()])
        # print("exact match:")
        # print(exact_match)
        assert len(exact_match) < 2 # at most one exact match is possible
        mentioned_concepts.update(exact_match)

    return mentioned_concepts
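
Since the Matcher link above is external, here is a minimal self-contained sketch of the pattern mechanism used above (spaCy 2.x API, matching the repo's code); the pattern itself is illustrative, while the real patterns are precomputed in the file loaded by load_matcher. The key point is that each rule is named after a concept, so nlp.vocab.strings[match_id] recovers the concept name.

import spacy
from spacy.matcher import Matcher

nlp = spacy.load("en_core_web_sm", disable=["ner", "parser", "textcat"])
matcher = Matcher(nlp.vocab)

# one rule per concept, keyed by the concept name (illustrative pattern, spaCy 2.x add signature)
matcher.add("revolving_door", None, [{"LOWER": "revolving"}, {"LOWER": "door"}])

doc = nlp("a revolving door is convenient for two direction travel.")
for match_id, start, end in matcher(doc):
    print(nlp.vocab.strings[match_id], "->", doc[start:end].text)
    # revolving_door -> revolving door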

With the concepts grounded, the next step is pruning.

Pruning removes concepts that contain stopwords and keeps only concepts that exist in the concept vocabulary. (A small self-contained illustration of the filtering rules follows the function.)

def prune(data, cpnet_vocab_path):
    # reload cpnet_vocab
    with open(cpnet_vocab_path, "r", encoding="utf8") as fin:
        cpnet_vocab = [l.strip() for l in fin]

    prune_data = []
    for item in tqdm(data):
        qc = item["qc"] # question concepts, e.g. ['blow', 'change', 'effort']
        prune_qc = []
        for c in qc:
            if c[-2:] == "er" and c[:-2] in qc: # 以er结尾 前两位在qc中 continue
                continue 
            if c[-1:] == "e" and c[:-1] in qc: # 以e结尾 前一位在qc中 continue
                continue
            have_stop = False # 是否存在停用词
            # remove all concepts having stopwords, including hard-grounded ones
            for t in c.split("_"):
                if t in nltk_stopwords:
                    have_stop = True 
            if not have_stop and c in cpnet_vocab: # 存在停用词就过滤掉 并且保存 c 在 concept的vocab中
                prune_qc.append(c)

        ac = item["ac"]
        prune_ac = []
        for c in ac:
            if c[-2:] == "er" and c[:-2] in ac:
                continue
            if c[-1:] == "e" and c[:-1] in ac:
                continue
            all_stop = True # for answer concepts, drop c only if every token is a stopword
            for t in c.split("_"):
                if t not in nltk_stopwords:
                    all_stop = False
            if not all_stop and c in cpnet_vocab:
                prune_ac.append(c)

        try:
            assert len(prune_ac) > 0 and len(prune_qc) > 0
        except Exception as e:
            pass
            # print("In pruning")
            # print(prune_qc)
            # print(prune_ac)
            # print("original:")
            # print(qc)
            # print(ac)
            # print()
        item["qc"] = prune_qc
        item["ac"] = prune_ac

        prune_data.append(item)
    return prune_data
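
To make the filtering rules concrete, the following toy, self-contained illustration (not the repo's code; the stopword set and vocabulary are placeholders) applies the same checks to a hand-made list of question concepts.

# Toy illustration of the question-concept pruning rules (placeholder resources):
toy_stopwords = {"a", "of", "the"}
toy_vocab = {"bank", "banker", "plan", "plane", "paper", "piece_of_paper"}

qc = ["bank", "banker", "plan", "plane", "piece_of_paper", "paper"]
pruned = []
for c in qc:
    if c[-2:] == "er" and c[:-2] in qc:   # 'banker' dropped because 'bank' is present
        continue
    if c[-1:] == "e" and c[:-1] in qc:    # 'plane' dropped because 'plan' is present
        continue
    if any(t in toy_stopwords for t in c.split("_")):  # 'piece_of_paper' contains 'of'
        continue
    if c in toy_vocab:
        pruned.append(c)

print(pruned)  # ['bank', 'plan', 'paper']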