cross-encoder方法
概述
该方法基于GPT模型,利用doc和query构造prompt。将prompt作为模型输入,利用query对应位置的输出结果计算log_softmax,该值可以反映输入词对应输出的预测概率,我们关注query中含有的词汇,对输出中对应query中含有的词汇进行求和,该值可以反映该doc和query相关的程度。
由于doc和query共同编码,因此对于语义的理解程度更好。
初始化模型并构造prompt
tokenizer = AutoTokenizer.from_pretrained("Muennighoff/SGPT-125M-weightedmean-msmarco-specb-bitfit",cache_dir = './SGPT-125M-weightedmean-msmarco-specb-bitfit')
model = AutoModelForCausalLM.from_pretrained("Muennighoff/SGPT-125M-weightedmean-msmarco-specb-bitfit",cache_dir = './SGPT-125M-weightedmean-msmarco-specb-bitfit')
prompt = 'Documents are searched to find matches with the same content.\nThe document "{}" is a good search result for "'
构造上下文并编码
context = prompt.format(doc)
# context_len = len(re.split('[,. ]', context))
context_enc = tokenizer.encode(context, add_special_tokens=False)
对query编码
这里对query编码是因为最后需要计算输出中对应于query中词的log_softmax的sum。
continuation_enc = tokenizer.encode(query, add_special_tokens=False)
初始化model inpt并计算概率
model_input = torch.tensor(context_enc + continuation_enc[:-1])
continuation_len = len(continuation_enc)
input_len, = model_input.shape
# [seq_len] -> [seq_len, vocab]
logprobs = torch.nn.functional.log_softmax(model(model_input)[0], dim=-1).cpu()
# [seq_len, vocab] -> [continuation_len, vocab]
logprobs = logprobs[input_len - continuation_len:]
# 获取query中对应的词的输出,continuation_enc
# Gather the log probabilities of the continuation tokens -> [continuation_len]
logprobs = torch.gather(logprobs, 1, torch.tensor(continuation_enc).unsqueeze(-1)).squeeze(-1)
score = torch.sum(logprobs)
小结
经过测试,该方法的性能表现优于bi-encoder,但是速度明显较慢。
基于文本的QA问答系统——corss-encoder方法&spm=1001.2101.3001.5002&articleId=124062794&d=1&t=3&u=ad198b817ca24c9b82a40fc672d6f10a)
1007

被折叠的 条评论
为什么被折叠?



