背景
利用LangChain中load_summarize_chain实现网页内容爬取并总结。
亮点:
网页内容过长,导致超过LLM的token限制,使用LangChain中load_summarize_chain实现。
Map-reduce思想:
- 先对长文本进行切分
- map阶段-对每段进行summary
- reduce阶段-对各段map产生的summary再进行汇总总结
- 实现长文本内容总结
案例实现:
背景:想查找某个产品的生产厂商,需要先去网页查找相关链接,然后分别总结每个链接的内容,最后对内容进行汇总。以下为代码:
# 使用google针对产品进行搜索,返回产品列表
import os
from autogen import config_list_from_json
import autogen
import requests
from bs4 import BeautifulSoup
import json
# from langchain.chat_models import ChatOpenAI
from langchain_community.chat_models import ChatOpenAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.summarize import load_summarize_chain
from langchain import PromptTemplate
import openai
from dotenv import load_dotenv
# Get API key
load_dotenv()
config_list3 = {"model": "gpt-3.5-turbo","api_key": "sk-xxxxx", "cache_seed": 42}
os.environ["OPENAI_API_KEY"] = "sk-xxxxx"
# summary chain:对每个url输出进行总结
def summary(product,content):
llm = ChatOpenAI(temperature=0, model="gpt-3.5-turbo-16k-0613")
text_splitter = RecursiveCharacterTextSplitter(
separators=["\n\n", "\n"], chunk_size=10000, chunk_overlap=500)
docs = text_splitter.create_documents([content])
map_prompt = """
content is :{text}
Please summarize the Chinese manufacturers of """+ product +""" ,based on the above content and return them in list format.
The returned results should be in the following format(Return strictly in list format.): ["manu1","manu2","manu3"...]
The manufacturers should be from the Chinese market, and it's preferred to use the full name of the manufacturers rather than abbreviations.
"""
combine_prompt = """
content is :{text}
Please summarize the Chinese manufacturers of """+ product +""" ,based on the above content and return them in list format.
The returned results should be in the following format(Return strictly in list format.): ["manu1","manu2","manu3"...]
The manufacturers should be from the Chinese market, and it's preferred to use the full name of the manufacturers rather than abbreviations.
"""
map_prompt_template = PromptTemplate(
template=map_prompt, input_variables=["text"])
combine_prompt_template = PromptTemplate(
template=combine_prompt, input_variables=["text"])
summary_chain = load_summarize_chain(
llm=llm,
chain_type='map_reduce',
map_prompt=map_prompt_template,
combine_prompt=combine_prompt_template,
verbose=False
)
output = summary_chain.run(input_documents=docs, )
# print(output)
return output
# print(summary("GPU","GPU的生产厂家有:七彩虹厂商,技嘉厂商."))
# print(type(summary("GPU","GPU的生产厂家有:七彩虹厂商,技嘉厂商.")))
# 抓取内容:
def scrape(product:str,url: str):
# scrape and summary
print("Scraping website...")
# Define the headers for the request
headers = {
'Cache-Control': 'no-cache',
'Content-Type': 'application/json',
}
# Define the data to be sent in the request
data = {
"url": url
}
# 转json
data_json = json.dumps(data)
# Send the POST request
response = requests.post(
"https://chrome.browserless.io/content?token=2db344e9-a08a-4179-8f48-195a2f7ea6ee", headers=headers, data=data_json)
# Check the response status code
if response.status_code == 200:
soup = BeautifulSoup(response.content, "html.parser")
text = soup.get_text()
# print("CONTENTTTTTT:", text)
# 不论长短都做总结 -> 生成厂商list
# text超长问题
if len(text) > 8000:
text = text[:8000]
output = summary(product,text)
try:
result_list = eval(output)
except Exception as e:
print("生成结果格式转化为list失败,返回为[]")
result_list = []
return result_list
else:
print(f"HTTP request failed with status code {response.status_code}")
# 查找
def search(query):
url = "https://google.serper.dev/search"
payload = json.dumps({
"q": query
})
headers = {
'X-API-KEY': 'do not use mine',
'Content-Type': 'application/json'
}
response = requests.request("POST", url, headers=headers, data=payload)
results = response.json()['organic']
# print(results)
product_manu=[]
for res in results[:10]:
if res["link"]:
res_manu = scrape(query, res["link"])
# 增加判断,如果返回是列表再去扩展
if isinstance(res_manu, list):
product_manu.extend(res_manu)
else:
print("the result of scrape is not list ,pass ")
else:
continue
print("****** product_manu is: \n",product_manu)
return response.json()
search("RTX3050显卡生产厂商")