背景
在RAG中,问题路由情况的一种解决方法:语义路由(Semantic Router)。例如,当我们处理新闻标题时,不同类型的新闻(如体育新闻、军事新闻、日常新闻)可能需要不同的处理逻辑。Langchain 提供了一种简洁且高效的方法来实现这种路由。
代码示例
from langchain_community.utils.math import cosine_similarity
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_community.embeddings import HuggingFaceEmbeddings
# 导入llm模型设置
from config import llm, model_path
# 定义不同类型新闻的模板
sport_template = """你是一个体育新闻记者,负责把有关体育新闻相关的标题进行扩写。如果你不知道该怎么写的话,就回答不知道。
这是标题:{title}"""
war_template = """你是一个军事新闻记者,负责把有关军事新闻相关的标题进行扩写。如果你不知道该怎么写的话,就回答不知道。
这是标题:{title}"""
normal_template = """你是一个日常新闻记者,负责把有关日常新闻相关的标题进行扩写。如果你不知道该怎么写的话,就回答不知道。
这是标题:{title}"""
# 加载模型
# 传入的参数是模型名称,可以是HuggingFace的模型名称,也可以是本地模型的路径
embeddings = HuggingFaceEmbeddings(model_name=model_path)
# 嵌入不同模板的语义向量
prompt_templates = [sport_template, war_template, normal_template]
prompt_embeddings = embeddings.embed_documents(prompt_templates)
# 定义语义路由函数
def prompt_router(input):
query_embedding = embeddings.embed_query(input["title"])
similarity = cosine_similarity([query_embedding], prompt_embeddings)[0]
most_similar = prompt_templates[similarity.argmax()]
if (most_similar == sport_template):
print("去找体育新闻记者")
elif (most_similar == war_template):
print("去找军事新闻记者")
else:
print("去找日常新闻记者")
prompt = PromptTemplate.from_template(most_similar)
return prompt
# 构建处理链
chain = (
{"title": RunnablePassthrough()} # 直接传递输入
| RunnableLambda(prompt_router) # 使用语义路由函数选择模板
| llm # 调用语言模型生成结果
| StrOutputParser() # 解析输出为字符串
)
# 测试代码
chain.invoke("title: 第三次世界大战在何时")
print("---------------------")
chain.invoke("title: 世界杯冠军花落谁家")
print("---------------------")
chain.invoke("title: 鹿晗官宣女友竟然是...")
生成结果:
去找军事新闻记者
---------------------
去找体育新闻记者
---------------------
去找日常新闻记者
Process finished with exit code 0