0.背景
微软发表了一篇论文《LongLLMLingua: Accelerating and Enhancing LLMs in Long Context Scenarios via Prompt Compression》,里面论述了利用PPL来进行RAG的检索排序,十分有趣。于是在自己业务数据上测试了下,虽然最终效果不敌bge rank, 但是表现已远超预期。
1.准备GPT2大模型
下载模型https://huggingface.co/uer/gpt2-chinese-cluecorpussmall/tree/main
2.计算PPL进行排序
# -*- encoding=utf-8 -*-
import os
import json
import math
import pandas as pd
from typing import List
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import AutoModelForCausalLM, AutoTokenizer
def init_ppl_model():
    """Load the local Chinese GPT-2 causal LM and tokenizer for PPL scoring.

    The model directory ``gpt2-chinese-cluecorpussmall`` is expected to sit
    next to this script. Everything runs on CPU.

    Returns:
        (model, tokenizer): the causal language model and its tokenizer.
    """
    device = "cpu"
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(script_dir, "./gpt2-chinese-cluecorpussmall")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    lm = AutoModelForCausalLM.from_pretrained(model_path, is_decoder=True)
    return lm.to(device), tokenizer
def get_ppl_longllmlingua(text, question, model, tokenizer):
    """Score ``question`` conditioned on ``text`` LongLLMLingua-style.

    Concatenates document and question, then averages the LM loss only over
    the question part (condition_mode="after"), i.e. the conditional
    perplexity of the question given the document.

    Returns:
        float: mean token loss over the question span.
    """
    def token_count(s: str, add_special_tokens: bool = True) -> int:
        return len(tokenizer(s, add_special_tokens=add_special_tokens).input_ids)

    score = get_ppl_new(
        model,
        tokenizer,
        text + question,
        granularity="sentence",
        condition_mode="after",
        # Start averaging the loss right after the document's tokens.
        condition_pos_id=token_count(text) - 1,
    )
    return score.item()
def get_ppl_new(
    model,
    tokenizer,
    text: str,
    granularity: str = "sentence",
    input_ids=None,
    attention_mask=None,
    past_key_values=None,
    return_kv=False,
    end=None,
    condition_mode: str = "none",
    condition_pos_id: int = 0,
):
    """Compute the causal-LM loss of ``text`` (LLMLingua-style PPL).

    Args:
        model: causal LM returning ``logits`` and ``past_key_values``.
        tokenizer: matching tokenizer; used only when ``input_ids`` is None.
        text: raw text to score (ignored when ``input_ids`` is given).
        granularity: "sentence" -> return mean loss; otherwise per-token loss.
        input_ids / attention_mask: pre-tokenized inputs (optional override).
        past_key_values: optional KV cache from a previous call; scoring then
            continues after the cached prefix.
        return_kv: also return the updated KV cache.
        end: exclusive end token index; capped to a 512-token window past
            the cached prefix (GPT-2's context limit).
        condition_mode: "before"/"after" keeps only the loss slice before /
            after ``condition_pos_id``; "none" keeps all tokens.
        condition_pos_id: split point (index into the loss vector) for the
            conditional modes.

    Returns:
        Tensor (scalar mean or per-token losses), plus the KV cache when
        ``return_kv`` is True.
    """
    device = "cpu"
    if input_ids is None:
        tokenized_text = tokenizer(text, return_tensors="pt")
        input_ids = tokenized_text["input_ids"].to(device)
        attention_mask = tokenized_text["attention_mask"].to(device)
    if past_key_values is not None:
        # Length of the already-cached prefix (cache shape: [..., seq, ...]).
        past_length = past_key_values[0][0].shape[2]
    else:
        past_length = 0
    if end is None:
        end = input_ids.shape[1]
    # Never score more than 512 new tokens beyond the cached prefix.
    end = min(end, past_length + 512)
    with torch.no_grad():
        response = model(
            input_ids[:, past_length:end],
            attention_mask=attention_mask[:, :end],
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = response.past_key_values
    # Standard next-token shift: logits at position i predict token i+1.
    shift_logits = response.logits[..., :-1, :].contiguous()
    shift_labels = input_ids[..., past_length + 1: end].contiguous()
    # Flatten the tokens; drop positions masked out by the attention mask.
    active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
    active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
    active_labels = shift_labels.view(-1)[active]
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    loss = loss_fct(active_logits, active_labels)
    #print(loss)
    # Keep only the conditional slice of the per-token losses, if requested.
    if condition_mode == "before":
        loss = loss[:condition_pos_id]
    elif condition_mode == "after":
        loss = loss[condition_pos_id:]
    res = loss.mean() if granularity == "sentence" else loss
    return (res, past_key_values) if return_kv else res
def ppl_rerank(retriever_res, ori_query, model, tokenizer):
    """Re-rank retrieved passages by conditional perplexity (ascending).

    Args:
        retriever_res: JSON string; ``[0]["list"]`` holds candidate dicts
            with ``content`` and ``infoMap`` fields.
        ori_query: the user query.
        model / tokenizer: PPL scorer from :func:`init_ppl_model`.

    Returns:
        str: the same JSON structure with each item's ``infoMap["ppl"]``
        filled in and the list sorted by ascending PPL (lower = better).
    """
    print("ori_query:", ori_query)
    retriever_res = json.loads(retriever_res)
    # Restrictive prompt appended after each document; loop-invariant,
    # so build it once.
    generate_text = ori_query + "我们可以从前面的文档中得到这个问题的答案。"
    for x in retriever_res[0]["list"]:
        # Character-level truncation so document + prompt stays within the
        # model's 512-token window (chars >= tokens for Chinese GPT-2).
        text_prefix = "{}".format(x["content"][:512 - len(generate_text)])
        # BUG FIX: the original called get_ppl_score_fixed_prefix, which is
        # not defined anywhere in this file (NameError at runtime); use the
        # LongLLMLingua scorer defined above instead.
        x["infoMap"]["ppl"] = get_ppl_longllmlingua(text_prefix, generate_text, model, tokenizer)
    retriever_res[0]["list"].sort(key=lambda d: d["infoMap"]["ppl"])
    return json.dumps(retriever_res, ensure_ascii=False)
def is_match(x1, x2):
    """Return True when candidate ``x1`` covers >95% of answer ``x2``'s chars.

    Both strings are reduced to their character sets after stripping; the
    score is |chars(x1) ∩ chars(x2)| / |chars(x2)| — i.e. how much of the
    reference answer's character set appears in the candidate.

    Args:
        x1: candidate passage text.
        x2: reference answer text.

    Returns:
        bool: True when the coverage score exceeds 0.95.
    """
    chars1 = set(x1.strip())
    chars2 = set(x2.strip())
    # BUG FIX: an empty/whitespace-only reference previously raised
    # ZeroDivisionError; treat it as "no match".
    if not chars2:
        return False
    #score = len(chars1 & chars2) / len(chars1 | chars2)
    score = len(chars1 & chars2) / len(chars2)
    return score > 0.95
def parse_rank_res(rank_res, answer):
    """Find the 1-based rank of the gold answer in the re-ranked list.

    Args:
        rank_res: JSON string produced by :func:`ppl_rerank`.
        answer: reference answer text to locate.

    Returns:
        int: 1-based position of the first matching candidate, or -1 when
        no candidate matches.
    """
    print("ori_query:", answer)
    ranked = json.loads(rank_res)
    for position, item in enumerate(ranked[0]["list"], start=1):
        if is_match(item["content"], answer):
            return position
    return -1
# --- Script entry: batch-evaluate PPL re-ranking on the benchmark file ---
df = pd.read_excel("./基准测试文件_全量QA测试_0119.xlsx")
#df = df[:10]
model, tokenizer = init_ppl_model()
"""
prefix_text = "我不会忘记"
generate_text = "和你一起奋斗的时光。"
get_ppl_longllmlingua(prefix_text, generate_text, model, tokenizer)
"""
# Re-rank each row's retrieved candidates by PPL, then record the 1-based
# rank ("hit_n") of the expected answer (column "预想答案"); -1 = not found.
df["retrievers_new"] = df.apply(lambda x: ppl_rerank(x["retrievers"], x["Query"], model, tokenizer), axis=1)
df["hit_n"] = df.apply(lambda x: parse_rank_res(x["retrievers_new"], x["预想答案"]), axis=1)
df.to_excel("./基准测试文件_全量QA测试_0119_ppl_rerank_5.xlsx")
3.结论
排序效果一般,目前效果比bge rerank略差,但是这种思想能取得这么好的排序效果,已经远超预期了。后续可持续关注该方向的进展。
tips:在自己业务数据集上restrictive prompt特别重要,可尝试不同的表述观察实验效果,最后祝你好运~
参考资料
https://huggingface.co/uer/gpt2-chinese-cluecorpussmall/tree/main
https://github.com/microsoft/LLMLingua/blob/main/README.md
https://github.com/microsoft/LLMLingua/blob/main/examples/Retrieval.ipynb