# The approach below is only a sketch.  To make it actually work you need:
# -1. a trained BERT model
#  0. a question-answering corpus for fine-tuning
#  1. similarity search: pick the passage most similar to the question to match against
#  2. find the best span while keeping start before end — maybe summing the two
#     logits works?  An O(n^2) scan is probably acceptable.
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer
import torch
# Path to a pretrained Chinese BERT checkpoint for extractive QA.
# NOTE(review): this should be fine-tuned on a QA corpus first — see header notes.
CHINESEBERT_PATH = "ChineseBERT-base"
model = BertForQuestionAnswering.from_pretrained(CHINESEBERT_PATH)
tokenizer = BertTokenizer.from_pretrained(CHINESEBERT_PATH)

question = "晚饭吃什么"
answer_text = "吃第一眼看到的两个菜"

# encode() yields: [CLS] question [SEP] context [SEP]
# e.g. 101, 3241, 7649, 1391, 784, 720, 102, 1391, ..., 5831, 102
# (101 == [CLS], 102 == [SEP])
input_ids = tokenizer.encode(question, answer_text)
tokens = tokenizer.convert_ids_to_tokens(input_ids)

# Split at the first [SEP]: segment 0 = question, segment 1 = context.
sep_index = input_ids.index(tokenizer.sep_token_id)
num_seg_a = sep_index + 1
num_seg_b = len(input_ids) - num_seg_a
segment_ids = [0] * num_seg_a + [1] * num_seg_b

# Inference only — disable gradient tracking to save memory.
with torch.no_grad():
    outputs = model(torch.tensor([input_ids]),               # token ids for our input text
                    token_type_ids=torch.tensor([segment_ids]),  # segment ids: question vs context
                    return_dict=True)
start_scores = outputs.start_logits[0]  # drop the batch dimension
end_scores = outputs.end_logits[0]

# BUG FIX: the original argmax'd start and end independently, so end could
# precede start and the span could include question tokens or [SEP]
# (the sample output was: 吃 什 么 [SEP] 吃 第 一 眼 看 到 的).
# Instead, scan every valid (start, end) pair inside the context segment
# (strictly after the first [SEP], before the trailing [SEP], start <= end)
# and keep the pair with the highest summed logit score — the O(n^2) scan
# suggested in the header notes; n is tiny here, so it is cheap.
best_score = float("-inf")
answer_start = answer_end = sep_index + 1
for start in range(sep_index + 1, len(input_ids) - 1):
    for end in range(start, len(input_ids) - 1):
        score = start_scores[start].item() + end_scores[end].item()
        if score > best_score:
            best_score = score
            answer_start, answer_end = start, end
print(answer_start, answer_end)

# Combine the tokens in the answer and print it out.
answer = ' '.join(tokens[answer_start:answer_end + 1])
print('Answer: "' + answer + '"')