KBQA Study Notes - Project Testing

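Contents

I. Overview

1. Overall flow: the main() function

2. Model loading functions

3. Entity extraction function

4. Attribute matching function

5. Database query function

6. Text matching function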


I. Overview

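The KBQA system is built from two models: one for entity recognition (NER) and one for attribute mapping. Once both models are trained, answering a question goes roughly through the following pipeline: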

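Input: "维纳斯是哪国的?" ("Which country is Venus from?")
1. Preprocess the text into the format the BERT+CRF model expects.
2. Feed it to the NER model, which recognizes the entity "维纳斯" (Venus).
3. Match the entity against the triples in the database, extracting its attributes such as "国籍" (nationality), "爱好" (hobbies) and "出生日期" (date of birth).
4. Preprocess the question into the format the BERT classification model expects.
5. Feed the question together with each extracted attribute into the attribute-mapping model.
6. The model finds that the question is most similar to the attribute "国籍".
7. Return the answer stored for "国籍".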
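To make the data flow concrete, here is a toy, self-contained sketch of the same pipeline. It is not the project's code: the triple store is an in-memory list with placeholder answers, the "NER model" is replaced by a longest-substring lookup over known entities, and the "attribute-mapping model" is replaced by counting shared characters. All names here (TRIPLES, find_entity, attribute_score, answer) are hypothetical.

```python
# Toy stand-ins for the two BERT models; placeholder data, not real knowledge-base content.
TRIPLES = [
    ("维纳斯", "国籍", "A1"),
    ("维纳斯", "爱好", "A2"),
    ("维纳斯", "出生日期", "A3"),
]


def find_entity(question, known_entities):
    """Stand-in for the NER model: the longest known entity contained in the question."""
    hits = [e for e in known_entities if e in question]
    return max(hits, key=len) if hits else ""


def attribute_score(question, attribute):
    """Stand-in for the similarity model: number of distinct shared characters."""
    return len(set(question) & set(attribute))


def answer(question, triples):
    entity = find_entity(question, {t[0] for t in triples})
    if not entity:
        return ""
    # all (attribute, answer) pairs stored for this entity
    candidates = [(a, v) for e, a, v in triples if e == entity]
    best_attr, best_val = max(candidates, key=lambda c: attribute_score(question, c[0]))
    if attribute_score(question, best_attr) == 0:
        return ""  # no attribute looks related to the question
    return "{}的{}是{}".format(entity, best_attr, best_val)


print(answer("维纳斯是哪国的?", TRIPLES))  # 维纳斯的国籍是A1
```

The real system swaps each stand-in for a trained model, but the control flow (find entity, fetch its triples, score every candidate attribute, answer with the best one) is the same.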

1. Overall flow: the main() function

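This function follows the pipeline above; both models are assumed to be already trained and saved under ./output/. One detail the overview skips: before calling the attribute-mapping model, main() first tries a cheap exact string match (text_match), and only falls back to the model when no attribute name appears literally in the question.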

import torch

# BertTokenizer / BertConfig / BertForSequenceClassification come from the Hugging Face library the
# models were trained with (the max_len and truncate_first_sequence arguments used in this post follow
# the older pytorch_transformers 1.x API). NerProcessor, SimProcessor, BertCrf, CRF_LABELS and
# SimInputFeatures come from the project's own training code.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def main():

    with torch.no_grad():
        tokenizer_inputs = ()
        tokenizer_kwargs = {'do_lower_case': False,
                            'max_len': 64,
                            'vocab_file': './input/config/bert-base-chinese-vocab.txt'}
        ner_processor = NerProcessor()
        sim_processor = SimProcessor()
        tokenizer = BertTokenizer(*tokenizer_inputs, **tokenizer_kwargs)

        # load both models once, before entering the question loop
        ner_model = get_ner_model(config_file='./input/config/bert-base-chinese-config.json',
                                  pre_train_model='./output/best_ner.bin',
                                  label_num=len(ner_processor.get_labels()))
        ner_model = ner_model.to(device)
        ner_model.eval()

        sim_model = get_sim_model(config_file='./input/config/bert-base-chinese-config.json',
                                  pre_train_model='./output/best_sim.bin',
                                  label_num=len(sim_processor.get_labels()))
        sim_model = sim_model.to(device)
        sim_model.eval()

        while True:
            print("=====" * 10)
            raw_text = input("Question:\n").strip()
            if raw_text == "quit":
                print("quit")
                return
            # step 1: recognize the entity with the NER model
            entity = get_entity(model=ner_model, tokenizer=tokenizer, sentence=raw_text, max_len=64)
            print("Entity:", entity)
            if entity == '':
                print("No entity found")
                continue
            # step 2: fetch every (entity, attribute, answer) triple for this entity
            # (the entity is pasted into the SQL string, so a quote character in it breaks the query)
            sql_str = "select * from nlpccqa where entity = '{}'".format(entity)
            triple_list = list(select_database(sql_str))
            if len(triple_list) == 0:
                print("No information found for {}".format(entity))
                continue
            # transpose the rows into columns: (entities, attributes, answers)
            triple_list = list(zip(*triple_list))
            attribute_list = triple_list[1]
            answer_list = triple_list[2]
            # step 3a: exact match of an attribute name inside the question
            attribute, answer = text_match(attribute_list, answer_list, raw_text)
            if attribute != '' and answer != '':
                # answer template: "<entity>'s <attribute> is <answer>"
                ret = "{}的{}是{}".format(entity, attribute, answer)
            else:
                # step 3b: fall back to the attribute-mapping model loaded above (no need to reload it)
                attribute_idx = semantic_matching(sim_model, tokenizer, raw_text, attribute_list, answer_list, 64).item()
                if attribute_idx == -1:
                    ret = ''
                else:
                    attribute = attribute_list[attribute_idx]
                    answer = answer_list[attribute_idx]
                    ret = "{}的{}是{}".format(entity, attribute, answer)
            if ret == '':
                print("No information found for {}".format(entity))
            else:
                print("Answer:", ret)
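The line `triple_list = list(zip(*triple_list))` transposes the rows returned by the database into one tuple per column. A small illustration with made-up rows, assuming the table's columns are (entity, attribute, answer) in that order, which is what the indices 1 and 2 used in main() imply:

```python
# Hypothetical rows in the shape cursor.fetchall() returns: (entity, attribute, answer)
rows = (
    ("维纳斯", "国籍", "A1"),
    ("维纳斯", "爱好", "A2"),
)

columns = list(zip(*rows))  # transpose: one tuple per column
entities, attribute_list, answer_list = columns
print(attribute_list)  # ('国籍', '爱好')
print(answer_list)     # ('A1', 'A2')
```

With zero rows, `list(zip(*()))` is an empty list and indexing it would raise IndexError, which is why main() checks for an empty result before transposing.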

2. Model loading functions

Each function builds the model architecture and then loads the trained weights that were saved during training.

def get_ner_model(config_file, pre_train_model, label_num=2):
    # build the BERT+CRF architecture, then load the trained weights;
    # map_location='cpu' lets a checkpoint saved on a GPU also load on a CPU-only machine
    model = BertCrf(config_name=config_file, num_tags=label_num, batch_first=True)
    model.load_state_dict(torch.load(pre_train_model, map_location=torch.device('cpu')))
    return model.to(device)

def get_sim_model(config_file, pre_train_model, label_num=2):
    # build a BERT sentence-pair classifier with label_num output classes, then load the trained weights
    bert_config = BertConfig.from_pretrained(config_file)
    bert_config.num_labels = label_num
    model = BertForSequenceClassification(bert_config)
    model.load_state_dict(torch.load(pre_train_model, map_location=torch.device('cpu')))
    return model

3. Entity extraction function

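The preprocessing is essentially the same as in training. The difference is that at inference time we take the tag sequence predicted by the model (decoded by the CRF layer) and read the entity off it directly.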

def get_entity(model, tokenizer, sentence, max_len=64):
    pad_token = 0  # id of [PAD] in the BERT vocabulary
    # split into single characters and join them with spaces, so each Chinese character becomes one token
    sentence_list = list(sentence.strip().replace(' ', ''))
    text = " ".join(sentence_list)
    inputs = tokenizer.encode_plus(
        text,
        add_special_tokens=True,       # wrap the input as [CLS] ... [SEP]
        max_length=max_len,
        truncate_first_sequence=True   # if the input is too long, truncate the first sequence first
    )
    input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
    # pad everything to max_len; attention_mask is 1 for real tokens and 0 for padding
    attention_mask = [1] * len(input_ids)
    padding_length = max_len - len(input_ids)
    input_ids = input_ids + ([pad_token] * padding_length)
    attention_mask = attention_mask + ([0] * padding_length)
    token_type_ids = token_type_ids + ([0] * padding_length)
    labels_ids = None  # no gold tags at inference time

    assert len(input_ids) == max_len, "Error with input length {} vs {}".format(len(input_ids), max_len)
    assert len(attention_mask) == max_len, "Error with input length {} vs {}".format(len(attention_mask), max_len)
    assert len(token_type_ids) == max_len, "Error with input length {} vs {}".format(len(token_type_ids), max_len)

    # shape (1, max_len): a batch containing a single sentence
    input_ids = torch.tensor(input_ids).reshape(1, -1).to(device)
    attention_mask = torch.tensor(attention_mask).reshape(1, -1).to(device)
    token_type_ids = torch.tensor(token_type_ids).reshape(1, -1).to(device)

    model = model.to(device)
    model.eval()
    # tags is None, so the loss returned by the model is None as well;
    # ret[1][0] is the predicted tag-id sequence for our single sentence
    ret = model(input_ids=input_ids,
                tags=labels_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids)
    pre_tag = ret[1][0]
    # one tag per character ([CLS]/[SEP] excluded), or max_len - 2 tags if the sentence was truncated
    assert len(pre_tag) == len(sentence_list) or len(pre_tag) == max_len - 2

    pre_tag_len = len(pre_tag)
    b_loc_idx = CRF_LABELS.index('B-LOC')
    i_loc_idx = CRF_LABELS.index('I-LOC')
    o_idx = CRF_LABELS.index('O')

    if b_loc_idx not in pre_tag and i_loc_idx not in pre_tag:
        print("No entity found in sentence [{}]".format(sentence))
        return ''
    # the entity starts at the first B-LOC tag, or at the first I-LOC tag if there is no B-LOC
    if b_loc_idx in pre_tag:
        entity_start_idx = pre_tag.index(b_loc_idx)
    else:
        entity_start_idx = pre_tag.index(i_loc_idx)
    entity_list = []
    entity_list.append(sentence_list[entity_start_idx])
    for i in range(entity_start_idx + 1, pre_tag_len):  # collect the remaining characters of the entity
        if pre_tag[i] == i_loc_idx:
            entity_list.append(sentence_list[i])
        else:
            break
    return "".join(entity_list)
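The decoding rule at the end of get_entity does not depend on BERT at all, so it can be tried in isolation. A plain-Python sketch of the same rule; the label order in CRF_LABELS below is an assumption for the example (the project defines its own list, which may contain more labels), and decode_first_entity is a hypothetical name:

```python
# Assumed label order, for illustration only.
CRF_LABELS = ["O", "B-LOC", "I-LOC"]


def decode_first_entity(chars, tag_ids, labels):
    """Start at the first B-LOC (else the first I-LOC) and extend over consecutive I-LOC tags."""
    b_idx, i_idx = labels.index("B-LOC"), labels.index("I-LOC")
    if b_idx in tag_ids:
        start = tag_ids.index(b_idx)
    elif i_idx in tag_ids:
        start = tag_ids.index(i_idx)
    else:
        return ""
    end = start + 1
    while end < len(tag_ids) and tag_ids[end] == i_idx:
        end += 1
    return "".join(chars[start:end])


print(decode_first_entity(list("维纳斯是哪国的"), [1, 2, 2, 0, 0, 0, 0], CRF_LABELS))  # 维纳斯
```

Like the original, this returns only the first entity span; a question that mentions two entities yields just the first one.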

4. Attribute matching function

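Each (question, candidate attribute) pair is fed to the sentence-pair classifier. We take the logits, convert them to probabilities, and pick the attribute that the model considers the best match for the question.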

from torch.utils.data import DataLoader, SequentialSampler, TensorDataset

def semantic_matching(model, tokenizer, question, attribute_list, answer_list, max_length):

    assert len(attribute_list) == len(answer_list)

    pad_token = 0
    pad_token_segment_id = 1  # segment id used for padding; must match the value used when building training features
    features = []
    # build one sentence pair "[CLS] question [SEP] attribute [SEP]" per candidate attribute
    for attribute in attribute_list:
        inputs = tokenizer.encode_plus(
            text=question,
            text_pair=attribute,
            add_special_tokens=True,
            max_length=max_length,
            truncate_first_sequence=True
        )
        input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]
        attention_mask = [1] * len(input_ids)

        padding_length = max_length - len(input_ids)
        input_ids = input_ids + ([pad_token] * padding_length)
        attention_mask = attention_mask + ([0] * padding_length)
        token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)

        assert len(input_ids) == max_length, "Error with input length {} vs {}".format(len(input_ids), max_length)
        assert len(attention_mask) == max_length, "Error with input length {} vs {}".format(len(attention_mask),
                                                                                            max_length)
        assert len(token_type_ids) == max_length, "Error with input length {} vs {}".format(len(token_type_ids),
                                                                                            max_length)
        features.append(
            SimInputFeatures(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        )
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
    all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)

    assert all_input_ids.shape == all_attention_mask.shape
    assert all_attention_mask.shape == all_token_type_ids.shape

    # score all candidate attributes in batches of 128, in their original order
    dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids)
    dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=128)

    model.eval()
    all_probs = []
    with torch.no_grad():
        for batch_input_ids, batch_attention_mask, batch_token_type_ids in dataloader:
            outputs = model(input_ids=batch_input_ids.to(device),
                            attention_mask=batch_attention_mask.to(device),
                            token_type_ids=batch_token_type_ids.to(device),
                            labels=None)
            logits = outputs[0]  # per attribute: scores for label 0 (no match) and label 1 (match)
            all_probs.append(logits.softmax(dim=-1))  # convert to probabilities
    all_probs = torch.cat(all_probs, dim=0)

    pred_labels = all_probs.argmax(dim=-1)
    if pred_labels.sum() == 0:
        # no attribute was classified as a match
        return torch.tensor(-1)
    # return the attribute with the highest match probability, not merely the first one labelled 1
    return all_probs[:, 1].argmax(dim=-1)
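The decision at the end of semantic_matching can be sketched in plain Python on a list of [p_no_match, p_match] pairs, one per candidate attribute. This hypothetical pick_attribute returns -1 when no candidate is classified as a match, and otherwise the index with the highest match probability. Note that taking argmax over the 0/1 predicted labels instead would return the first attribute labelled 1, which is not necessarily the most confident one.

```python
def pick_attribute(probs):
    """probs[i] = [p_no_match, p_match] for candidate attribute i.

    Returns -1 if no candidate is classified as a match (p_match > p_no_match),
    otherwise the index of the candidate with the highest p_match.
    """
    if not any(p1 > p0 for p0, p1 in probs):
        return -1
    return max(range(len(probs)), key=lambda i: probs[i][1])


print(pick_attribute([[0.9, 0.1], [0.3, 0.7], [0.2, 0.8]]))  # 2
print(pick_attribute([[0.9, 0.1], [0.6, 0.4]]))              # -1
```

In the first call both the second and third candidates are classified as matches; the "first match" rule would pick index 1, while this rule picks index 2.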

5. Database query function

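import pymysql

def select_database(sql):
    # connect to the local MySQL database holding the knowledge-base triples
    connect = pymysql.connect(user="******", password="******", host="127.0.0.1", port=3306, db="kb_qa", charset="utf8")
    cursor = connect.cursor()  # create a cursor
    results = ()  # returned as-is if the query fails, so the caller sees "no rows" instead of a NameError
    try:
        # execute the SQL statement
        cursor.execute(sql)
        # fetch all rows
        results = cursor.fetchall()
    except Exception as e:
        print("Error: unable to fetch data: %s, %s" % (repr(e), sql))
    finally:
        # close the cursor and the database connection
        cursor.close()
        connect.close()
    return results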
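select_database receives a fully formatted SQL string, and main() builds it with str.format, so an entity containing a single quote breaks the query. A sketch of the same lookup with a bound parameter, using Python's built-in sqlite3 and an in-memory table as a stand-in for the MySQL nlpccqa table. The column names entity/attribute/answer and the rows are assumptions for illustration; with pymysql the placeholder would be %s instead of ?.

```python
import sqlite3


def select_triples(conn, entity):
    """Fetch all (entity, attribute, answer) rows for one entity using a bound parameter."""
    cur = conn.execute(
        "SELECT entity, attribute, answer FROM nlpccqa WHERE entity = ?", (entity,)
    )
    return cur.fetchall()


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE nlpccqa (entity TEXT, attribute TEXT, answer TEXT)")
conn.executemany(
    "INSERT INTO nlpccqa VALUES (?, ?, ?)",
    [("维纳斯", "国籍", "A1"), ("维纳斯", "爱好", "A2"), ("O'Neil", "国籍", "A3")],  # placeholder rows
)
print(len(select_triples(conn, "维纳斯")))  # 2
print(len(select_triples(conn, "O'Neil")))  # 1, the quote needs no escaping
```

Because the driver passes the entity separately from the SQL text, quotes and other special characters in the entity cannot change the structure of the query.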

6. Text matching function

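This runs after the entity has been recognized and its attributes have been fetched from the database. It simply checks whether any candidate attribute name appears verbatim in the question; if one does, the answer stored for that attribute is returned directly, without calling the attribute-mapping model.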

def text_match(attribute_list, answer_list, sentence):

    assert len(attribute_list) == len(answer_list)

    # return the first attribute whose name appears verbatim in the question, with its answer
    for attribute, answer in zip(attribute_list, answer_list):
        if attribute in sentence:
            return attribute, answer
    return "", ""
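Because text_match returns the first attribute found in the question, the row order decides the result when several attribute names occur in it, for example a short name and a longer name that contains it. A hypothetical variant that prefers the longest matching attribute name (a design choice for illustration, not part of the original project; the attribute "出生" below is made up):

```python
def text_match_longest(attribute_list, answer_list, sentence):
    """Return the (attribute, answer) pair whose attribute name is the longest one
    appearing verbatim in the sentence, or ("", "") if none appears."""
    assert len(attribute_list) == len(answer_list)
    hits = [(a, v) for a, v in zip(attribute_list, answer_list) if a in sentence]
    if not hits:
        return "", ""
    return max(hits, key=lambda pair: len(pair[0]))


print(text_match_longest(["出生", "出生日期"], ["X", "Y"], "维纳斯的出生日期是什么"))  # ('出生日期', 'Y')
```

With the first-match rule, the same input would return ('出生', 'X').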