Introduction
While running RAG experiments with LlamaIndex, I hit the following error when computing similarity scores with BM25. This post records the fix.
The error message
Traceback (most recent call last):
  File "C:\Users\js\Desktop\cache\llama_index_demo\recall\bm25_fix_bug.py", line 50, in <module>
    res = retriever._retrieve(QueryBundle(q))
  File "C:\Users\js\anaconda3\envs\llm\lib\site-packages\llama_index\legacy\retrievers\bm25_retriever.py", line 99, in _retrieve
    scored_nodes = self._get_scored_nodes(query_bundle.query_str)
  File "C:\Users\js\anaconda3\envs\llm\lib\site-packages\llama_index\legacy\retrievers\bm25_retriever.py", line 91, in _get_scored_nodes
    nodes.append(NodeWithScore(node=node, score=doc_scores[i]))
  File "C:\Users\js\anaconda3\envs\llm\lib\site-packages\pydantic\v1\main.py", line 341, in __init__
    raise validation_error
pydantic.v1.error_wrappers.ValidationError: 1 validation error for NodeWithScore
node
  Can't instantiate abstract class BaseNode with abstract methods get_content, get_metadata_str, get_type, hash, set_content (type=type_error)
Fixing the bug
Standard data loading
from llama_index.core import (
    SimpleDirectoryReader,
)
from llama_index.core.node_parser import SentenceWindowNodeParser
from llama_index.legacy.retrievers import BM25Retriever
from llama_index.legacy.schema import NodeWithScore, QueryBundle, Node

# Load data
documents = SimpleDirectoryReader(
    input_files=["./data/paul_graham_essay.txt"]
).load_data()

# Create the sentence window node parser with default settings
node_parser = SentenceWindowNodeParser.from_defaults(
    window_size=3,
    window_metadata_key="window",
    original_text_metadata_key="original_text",
)

# Extract nodes from documents
nodes = node_parser.get_nodes_from_documents(documents)

# By default, node ids are random uuids. Set them manually so the ids
# are the same on every run.
for idx, node in enumerate(nodes):
    node.id_ = f"node_{idx}"
Overriding the method
The traceback shows that the error occurs inside the _get_scored_nodes method. The fix: create a subclass of BM25Retriever and override _get_scored_nodes.
class JieRetriever(BM25Retriever):
    def _get_scored_nodes(self, query: str):
        tokenized_query = self._tokenizer(query)
        doc_scores = self.bm25.get_scores(tokenized_query)
        nodes = []
        for i, node in enumerate(self._nodes):
            # Rebuild a concrete Node from the stored node's dict, so that
            # pydantic does not try to instantiate the abstract BaseNode.
            node_new = Node.from_dict(node.to_dict())
            node_with_score = NodeWithScore(node=node_new, score=doc_scores[i])
            nodes.append(node_with_score)
        return nodes

if __name__ == '__main__':
    retriever = JieRetriever.from_defaults(
    # retriever = BM25Retriever.from_defaults(
        similarity_top_k=5,
        nodes=nodes,
    )
    q = 'Before college the two main things I worked on, outside of school, were writing and programming'
    from pprint import pprint
    res = retriever._retrieve(QueryBundle(q))
    for item in res:
        pprint([item.node.id_, item.score, item.node.get_content()])
The key change is building a new concrete Node before wrapping it:

node_new = Node.from_dict(node.to_dict())
node_with_score = NodeWithScore(node=node_new, score=doc_scores[i])
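Why this round-trip helps can be sketched with a toy abstract base class (the names below are illustrative, not llama-index API): Python refuses to instantiate a class that still has abstract methods, which is exactly what the pydantic validation path trips over with BaseNode; rebuilding the object as a concrete subclass from the same serialized data sidesteps that.

```python
# Toy illustration (not llama-index code): instantiating an abstract class
# fails, but a concrete subclass rebuilt from the same dict works fine.
from abc import ABC, abstractmethod


class BaseThing(ABC):
    def __init__(self, text: str):
        self.text = text

    @abstractmethod
    def get_content(self) -> str: ...

    def to_dict(self) -> dict:
        return {"text": self.text}


class ConcreteThing(BaseThing):
    def get_content(self) -> str:
        return self.text

    @classmethod
    def from_dict(cls, d: dict) -> "ConcreteThing":
        return cls(d["text"])


try:
    BaseThing("hello")  # raises: can't instantiate abstract class BaseThing
except TypeError as e:
    print(e)

node = ConcreteThing("hello")
# Same serialize-then-rebuild pattern as Node.from_dict(node.to_dict())
rebuilt = ConcreteThing.from_dict(node.to_dict())
print(rebuilt.get_content())
```

The fix in JieRetriever does the same thing: it does not change any data, it only changes the runtime type of the node to one that pydantic can construct.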
References
After hitting this error, it took me a while to realize it was a bug in the llama-index package, so I searched for the error message directly in the project's GitHub issues. Sure enough, others had run into the same problem, and I adapted the fix that someone had posted there.