(五)基于文本的QA问答系统——corss-encoder方法

cross-encoder方法

概述

该方法基于GPT模型,利用doc和query构造prompt。将prompt作为模型输入,利用query对应位置的输出结果计算log_softmax,该值可以反映输入词对应输出的预测概率,我们关注query中含有的词汇,对输出中对应query中含有的词汇进行求和,该值可以反映该doc和query相关的程度。

由于doc和query共同编码,因此对于语义的理解程度更好。

初始化模型并构造prompt

tokenizer = AutoTokenizer.from_pretrained("Muennighoff/SGPT-125M-weightedmean-msmarco-specb-bitfit",cache_dir = './SGPT-125M-weightedmean-msmarco-specb-bitfit')
model = AutoModelForCausalLM.from_pretrained("Muennighoff/SGPT-125M-weightedmean-msmarco-specb-bitfit",cache_dir = './SGPT-125M-weightedmean-msmarco-specb-bitfit')

prompt = 'Documents are searched to find matches with the same content.\nThe document "{}" is a good search result for "'

构造上下文并编码

context = prompt.format(doc)
# context_len = len(re.split('[,. ]', context))
context_enc = tokenizer.encode(context, add_special_tokens=False)

对query编码

这里对query编码是因为最后需要计算输出中对应于query中词的log_softmax的sum。

continuation_enc = tokenizer.encode(query, add_special_tokens=False)

初始化model inpt并计算概率

 model_input = torch.tensor(context_enc + continuation_enc[:-1])
continuation_len = len(continuation_enc)
input_len, = model_input.shape

# [seq_len] -> [seq_len, vocab]
logprobs = torch.nn.functional.log_softmax(model(model_input)[0], dim=-1).cpu()
# [seq_len, vocab] -> [continuation_len, vocab]
logprobs = logprobs[input_len - continuation_len:]
# 获取query中对应的词的输出,continuation_enc
# Gather the log probabilities of the continuation tokens -> [continuation_len]
logprobs = torch.gather(logprobs, 1, torch.tensor(continuation_enc).unsqueeze(-1)).squeeze(-1)
score = torch.sum(logprobs)

小结

经过测试,该方法的性能表现优于bi-encoder,但是速度明显较慢。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值