一、简介
在本文中,我们将探讨如何使用 SQLCoder-7B(我们将在 Amazon SageMaker 上部署的大型语言模型 (LLM))和 LangChain 来执行自然语言查询 (NLQ)。
我们将了解如何使用 LangChain 创建一个管道,提示 LLM 生成 SQL 查询,从 PostgreSQL 数据库检索数据,并将结果作为上下文传递给 LLM 以获得最终响应。
二、SQLCODER
SQLCoder 是大型语言模型 (LLM) 的集合,用于从自然语言高效生成 SQL 查询。
我们将使用 SQLCoder-7B,它基于 Mistral-7B 并针对 SQL 查询生成进行了微调。
根据其创建者的说法"SQLCoder-7B 在自然语言到 SQL 任务中的表现优于 GPT-3.5 Turbo 和其他流行的开源模型。此外,在对特定数据库模式进行微调时,它甚至超过了 GPT-4"。
三、设置环境
1、配置数据库
首先,让我们使用 Amazon RDS 配置 PostgreSQL 数据库。
我们将使用下面的 Terraform 代码片段来完成这项任务:
# ------------------------------------------------------------------------------
# RDS Security group
# ------------------------------------------------------------------------------
resource "aws_security_group" "db_sg" {
name_prefix = local.db_security_group_name_prefix
vpc_id = local.vpc_id
ingress {
from_port = local.db_port
to_port = local.db_port
protocol = "tcp"
cidr_blocks = [local.my_ip_address]
}
egress {
from_port = 0
to_port = 0
protocol = "-1"
cidr_blocks = ["0.0.0.0/0"]
}
}
# ------------------------------------------------------------------------------
# RDS
# ------------------------------------------------------------------------------
module "db" {
source = "terraform-aws-modules/rds/aws"
identifier = local.db_identifier
engine = "postgres"
engine_version = "15.4"
family = "postgres15"
instance_class = local.db_instance_class
allocated_storage = local.db_allocated_storage
db_name = local.db_name
username = local.db_username
port = local.db_port
create_db_subnet_group = true
vpc_security_group_ids = [aws_security_group.db_sg.id]
subnet_ids = local.db_subnet_ids
}
2、初始化数据库
现在,让我们用电子商务数据集中的数据填充新创建的数据库,您可以从 Kaggle 下载 CSV 文件。
下载数据集后,使用以下命令连接到 RDS 数据库:
psql -h <RDS_ENDPOINT> -p <RDS_PORT> -U <DATABASE_USERNAME> -d <DATABASE_NAME> -W
根据提示输入密码。连接后,使用以下 SQL 命令创建销售表:
CREATE TABLE sales (
invoiceno VARCHAR(255),
stockcode VARCHAR(255),
description VARCHAR(255),
quantity INT,
invoicedate TIMESTAMP,
unitprice DECIMAL(10, 2),
customerid INT,
country VARCHAR(50)
);
接下来,使用以下命令将 CSV 文件中的数据复制到销售表中:
\COPY sales(invoiceno, stockcode, description, quantity, invoicedate, unitprice, customerid, country) FROM '/path/to/data.csv' DELIMITER ',' CSV HEADER;
最后,运行一个简单的计数查询,验证数据是否已成功加载:
SELECT COUNT(*) FROM sales;
3、在 Amazon sagemaker 上部署 sqlcoder-7b
如果你已经按照我之前的文章部署了自己的私人 LLM 聊天机器人,只需更新代码,按如下方式部署模型 defog/sqlcoder-7b:
locals {
hugging_face_model_id = "defog/sqlcoder-7b"
}
四、创建LANGCHAIN管道
1、编写剧本
现在,您的数据库已经建立并填充了数据,模型也已部署,我们将创建一个简单的脚本,使用数据库连接和 Amazon SageMaker Endpoint 创建 SQLDatabaseChain。
如果您对检索增强生成(RAG)的概念不了解,请参阅我之前的文章《使用亚马逊 Bedrock 和 LangChain 创建上下文感知的 LLM 聊天机器人》。
使用下面的 Python 脚本可以对数据库中存储的数据执行自然语言查询:
import boto3
import json
from langchain.sql_database import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint, LLMContentHandler
from typing import Dict
from sqlalchemy.exc import ProgrammingError
# RDS configuration
RDS_DB_NAME = "<RDS_DB_NAME>"
RDS_ENDPOINT = "<RDS_ENDPOINT>"
RDS_USERNAME = "<RDS_USERNAME>"
RDS_PASSWORD = "<RDS_PASSWORD>"
RDS_PORT = "<RDS_PORT>"
RDS_URI = f"postgresql+psycopg2://{RDS_USERNAME}:{RDS_PASSWORD}@{RDS_ENDPOINT}:{RDS_PORT}/{RDS_DB_NAME}"
db = SQLDatabase.from_uri(
RDS_URI,
include_tables=["sales"],
sample_rows_in_table_info=2,
)
# Sagemaker configuration
SAGEMAKER_ENDPOINT_NAME = "<SAGEMAKER_ENDPOINT_NAME>"
MAX_TOKENS = 1024
class ContentHandler(LLMContentHandler):
content_type = "application/json"
accepts = "application/json"
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
input_str = json.dumps({"inputs": prompt.strip(), "parameters": model_kwargs})
return input_str.encode("utf-8")
def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
response = response_json[0]["generated_text"].strip().split("\n")[0]
return response
content_handler = ContentHandler()
sagemaker_client = boto3.client("runtime.sagemaker")
llm = SagemakerEndpoint(
client=sagemaker_client,
endpoint_name=SAGEMAKER_ENDPOINT_NAME,
model_kwargs={
"max_new_tokens": MAX_TOKENS,
"return_full_text": False,
},
content_handler=content_handler,
)
# Chain
db_chain = SQLDatabaseChain.from_llm(
llm,
db,
verbose=True,
)
while True:
user_input = input("Enter a message (or 'exit' to quit): ")
if user_input.lower() == "exit":
break
try:
results = db_chain.run(user_input)
print(results)
except (ProgrammingError, ValueError) as exc:
print(f"\n\n{exc}")
运行此脚本时,它会提示用户输入信息,然后这些信息将通过 SQLDatabaseChain 传递。
此脚本仅供演示之用,可进一步定制。
在与 LLM 交互时,LangChain 可灵活自定义提示,以获得更好的效果。
定制脚本的另一种方法是向 LLM 提供详细的表定义。
这样做可以提供有关被查询表结构的额外上下文,从而帮助 LLM 生成更准确、更相关的 SQL 查询。
2、测试脚本
让我们看看引擎盖下发生了什么。
我运行了脚本,并输入了问题 "最畅销的产品是什么?
在这里,我们可以看到 LangChain 生成了一个提示,其中包含表模式和表中的 2 行。
然后,该提示请求亚马逊 SageMaker 端点根据给定上下文创建 SQL 查询。
结果,模型返回了生成的 SQL 查询,如下图所示:
LangChain 在数据库中执行了查询,得到了以下结果:
有了这些结果,LangChain 再次请求 LLM 模型,在提示符中填写 SQLResult 并请求回答。
模型给出了最终答案,如下图所示:
从这一成功结果来看,脚本似乎能够处理自然语言查询。
要进一步测试脚本,可以考虑尝试其他复杂查询。
必须考虑使用只读用户连接数据库,因为 LLM 有可能生成插入、删除和更改等数据处理语言(DML)查询。
五、结论
这种集成为各种应用打开了大门,从数据分析到能够回答复杂数据库相关查询的聊天机器人。如果您在集成过程中遇到任何难题,或有本指南未涵盖的特殊要求,请随时联系我。我将竭诚为您服务!