LLM系列 | 12: 实测OpenAI函数调用功能:以数据库问答为例

11 篇文章 2 订阅

简介

黑云翻墨未遮山,白雨跳珠乱入船。
在这里插入图片描述

小伙伴们好,我是微信公众号《小窗幽记机器学习》的小编:划龙舟的小男孩。紧接前面几篇ChatGPT Prompt工程和应用系列文章:

更多、更新文章欢迎关注微信公众号:小窗幽记机器学习。后续会持续整理模型加速、模型部署、模型压缩、LLM、AI艺术等系列专题,敬请关注。

今天这篇小作文以数据库问答(Text2SQL)为例进一步介绍ChatGPT的函数调用。本文将介绍如何将模型生成的结果输入到自定义的函数中,并利用该功能实现数据库问答功能。为简单起见,这里将使用Chinook 示例数据库

需要特别注意:
生产环境中,生成的SQL可能存在较高风险。因为模型在生成正确的 SQL 这方面暂不完全可靠,小伙伴们评估谨慎使用

数据库相关

环境相关设置及其辅助函数代码请于文末附录部分。以下直接介绍示例数据库相关细节。

获取 Chinook 数据相关的信息:

import sqlite3

conn = sqlite3.connect("data/chinook.db")
print("Opened database successfully")

def get_table_names(conn):
    """Return a list of table names."""
    table_names = []
    tables = conn.execute("SELECT name FROM sqlite_master WHERE type='table';")
    for table in tables.fetchall():
        table_names.append(table[0])
    return table_names


def get_column_names(conn, table_name):
    """Return a list of column names."""
    column_names = []
    columns = conn.execute(f"PRAGMA table_info('{table_name}');").fetchall()
    for col in columns:
        column_names.append(col[1])
    return column_names


def get_database_info(conn):
    """Return a list of dicts containing the table name and columns for each table in the database."""
    table_dicts = []
    for table_name in get_table_names(conn):
        columns_names = get_column_names(conn, table_name)
        table_dicts.append({"table_name": table_name, "column_names": columns_names})
    return table_dicts

获取 db 中的table

可以获取chinook db中有哪些 table:

table_names = get_table_names(conn)
print("table_names=", table_names)

输出结果如下:

table_names= ['albums', 'sqlite_sequence', 'artists', 'customers', 'employees', 'genres', 'invoices', 'invoice_items', 'media_types', 'playlists', 'playlist_track', 'tracks', 'sqlite_stat1']

获取各 table 的schema

database_schema_dict = get_database_info(conn)
database_schema_string = "\n".join(
    [
        f"Table: {table['table_name']}\nColumns: {', '.join(table['column_names'])}"
        for table in database_schema_dict
    ]
)

database_schema_dict结果如下:

[{'table_name': 'albums', 'column_names': ['AlbumId', 'Title', 'ArtistId']},
 {'table_name': 'sqlite_sequence', 'column_names': ['name', 'seq']},
 {'table_name': 'artists', 'column_names': ['ArtistId', 'Name']},
 {'table_name': 'customers',
  'column_names': ['CustomerId',
   'FirstName',
   'LastName',
   'Company',
   'Address',
   'City',
   'State',
   'Country',
   'PostalCode',
   'Phone',
   'Fax',
   'Email',
   'SupportRepId']},
 {'table_name': 'employees',
  'column_names': ['EmployeeId',
   'LastName',
   'FirstName',
   'Title',
   'ReportsTo',
   'BirthDate',
   'HireDate',
   'Address',
   'City',
   'State',
   'Country',
   'PostalCode',
   'Phone',
   'Fax',
   'Email']},
 {'table_name': 'genres', 'column_names': ['GenreId', 'Name']},
 {'table_name': 'invoices',
  'column_names': ['InvoiceId',
   'CustomerId',
   'InvoiceDate',
   'BillingAddress',
   'BillingCity',
   'BillingState',
   'BillingCountry',
   'BillingPostalCode',
   'Total']},
 {'table_name': 'invoice_items',
  'column_names': ['InvoiceLineId',
   'InvoiceId',
   'TrackId',
   'UnitPrice',
   'Quantity']},
 {'table_name': 'media_types', 'column_names': ['MediaTypeId', 'Name']},
 {'table_name': 'playlists', 'column_names': ['PlaylistId', 'Name']},
 {'table_name': 'playlist_track', 'column_names': ['PlaylistId', 'TrackId']},
 {'table_name': 'tracks',
  'column_names': ['TrackId',
   'Name',
   'AlbumId',
   'MediaTypeId',
   'GenreId',
   'Composer',
   'Milliseconds',
   'Bytes',
   'UnitPrice']},
 {'table_name': 'sqlite_stat1', 'column_names': ['tbl', 'idx', 'stat']}]

database_schema_string结果如下:

'Table: albums\nColumns: AlbumId, Title, ArtistId\nTable: sqlite_sequence\nColumns: name, seq\nTable: artists\nColumns: ArtistId, Name\nTable: customers\nColumns: CustomerId, FirstName, LastName, Company, Address, City, State, Country, PostalCode, Phone, Fax, Email, SupportRepId\nTable: employees\nColumns: EmployeeId, LastName, FirstName, Title, ReportsTo, BirthDate, HireDate, Address, City, State, Country, PostalCode, Phone, Fax, Email\nTable: genres\nColumns: GenreId, Name\nTable: invoices\nColumns: InvoiceId, CustomerId, InvoiceDate, BillingAddress, BillingCity, BillingState, BillingCountry, BillingPostalCode, Total\nTable: invoice_items\nColumns: InvoiceLineId, InvoiceId, TrackId, UnitPrice, Quantity\nTable: media_types\nColumns: MediaTypeId, Name\nTable: playlists\nColumns: PlaylistId, Name\nTable: playlist_track\nColumns: PlaylistId, TrackId\nTable: tracks\nColumns: TrackId, Name, AlbumId, MediaTypeId, GenreId, Composer, Milliseconds, Bytes, UnitPrice\nTable: sqlite_stat1\nColumns: tbl, idx, stat'

定义相关函数

定义functions规范

注意,在定义functions规范时要将数据库的schema插入到函数规范中,这对模型来说是很重要的。

functions = [
    {
        "name": "ask_database",
        "description": "请使用以下函数来回答关于音乐的用户问题。输出结果应为一个完整的 SQL 查询。",
        "parameters": {
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": f"""
                            用于提取信息以回答用户问题的 SQL 查询。
                            SQL 查询应使用以下数据库模式编写:
                            {database_schema_string}
                            查询应以纯文本形式返回,而不是 JSON 格式。
                            """,
                }
            },
            "required": ["query"],
        },
    }
]

定义执行SQL语句的函数

# ChatGPT 生成的query会输入到 ask_database
def ask_database(conn, query):
    """Function to query SQLite database with a provided SQL query."""
    try:
        results = str(conn.execute(query).fetchall())
    except Exception as e:
        results = f"query failed with error: {e}"
    return results

# 根据 message["function_call"]["name"] 判断函数调用时机
def execute_function_call(message):
    if message["function_call"]["name"] == "ask_database":
        query = json.loads(message["function_call"]["arguments"])["query"]
        results = ask_database(conn, query)
    else:
        results = f"Error: function {message['function_call']['name']} does not exist"
    return results

示例1:查询曲目数量Top5的艺术家

messages = []
messages.append({"role": "system", "content": "基于 Chinook 音乐数据库生成 SQL 查询来回答用户的问题。"})
messages.append({"role": "user", "content": "你好,按照曲目数量,排名前5位的艺术家有谁?"})
chat_response = chat_completion_request(messages, functions)
print("chat_response=", chat_response.json())
assistant_message = chat_response.json()["choices"][0]["message"]
print("assistant_message=", assistant_message)
messages.append(assistant_message)
if assistant_message.get("function_call"):
    results = execute_function_call(assistant_message)
    messages.append({"role": "function", "name": assistant_message["function_call"]["name"], "content": results})
pretty_print_conversation(messages)

返回结果如下:

chat_response= {'id': 'chatcmpl-7TQdhItIU3FEgvvYliCGRoD0QetgX', 'object': 'chat.completion', 'created': 1687248269, 'model': 'gpt-3.5-turbo-0613', 'choices': [{'index': 0, 'message': {'role': 'assistant', 'content': None, 'function_call': {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}}, 'finish_reason': 'function_call'}], 'usage': {'prompt_tokens': 448, 'completion_tokens': 67, 'total_tokens': 515}}

assistant_message= {'role': 'assistant', 'content': None, 'function_call': {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}}

最终pretty_print_conversation(messages)的结果如下:

system: 基于 Chinook 音乐数据库生成 SQL 查询来回答用户的问题。

user: 你好,按照曲目数量,排名前5位的艺术家有谁?

assistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}

function (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Deep Purple', 92)]

示例2:查询哪个专辑曲目最多

messages.append({"role": "user", "content": "曲目最多的专辑的名称是什么?"})
chat_response = chat_completion_request(messages, functions)
assistant_message = chat_response.json()["choices"][0]["message"]
messages.append(assistant_message)
if assistant_message.get("function_call"):
    results = execute_function_call(assistant_message)
    messages.append({"role": "function", "content": results, "name": assistant_message["function_call"]["name"]})
pretty_print_conversation(messages)

输出结果如下:

system: 基于 Chinook 音乐数据库生成 SQL 查询来回答用户的问题。

user: 你好,按照曲目数量,排名前5位的艺术家有谁?

assistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT artists.Name, COUNT(tracks.TrackId) AS NumTracks FROM artists JOIN albums ON artists.ArtistId = albums.ArtistId JOIN tracks ON albums.AlbumID = tracks.AlbumId GROUP BY artists.ArtistId ORDER BY NumTracks DESC LIMIT 5;"\n}'}

function (ask_database): [('Iron Maiden', 213), ('U2', 135), ('Led Zeppelin', 114), ('Metallica', 112), ('Deep Purple', 92)]

user: 曲目最多的专辑的名称是什么?

assistant: {'name': 'ask_database', 'arguments': '{\n  "query": "SELECT albums.Title, COUNT(tracks.TrackId) AS NumTracks FROM albums JOIN tracks ON albums.AlbumId = tracks.AlbumId GROUP BY albums.AlbumId ORDER BY NumTracks DESC LIMIT 1;"\n}'}

function (ask_database): [('Greatest Hits', 57)]

小结

通过上述示例可以确切感受openai函数调用功能的强大,这也为开发者构建更多稳健服务提供更强的保障。

附录

import json
import openai
import requests
import os
from tenacity import retry, wait_random_exponential, stop_after_attempt
from termcolor import colored

GPT_MODEL = "gpt-3.5-turbo-0613"
openai.api_key  = "sk-xxx"
os.environ['HTTP_PROXY'] = "xxx"
os.environ['HTTPS_PROXY'] = "xxx"

# 调用API的重试机制
@retry(wait=wait_random_exponential(min=1, max=40), stop=stop_after_attempt(3))
def chat_completion_request(messages, functions=None, function_call=None, model=GPT_MODEL):
    headers = {
        "Content-Type": "application/json",
        "Authorization": "Bearer " + openai.api_key,
    }
    json_data = {"model": model, "messages": messages}
    if functions is not None:
        json_data.update({"functions": functions})
    if function_call is not None:
        json_data.update({"function_call": function_call})
    try:
        response = requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=json_data,
        )
        return response
    except Exception as e:
        print("Unable to generate ChatCompletion response")
        print(f"Exception: {e}")
        return e


# 处理输出,方便阅读
def pretty_print_conversation(messages):
    role_to_color = {
        "system": "red",
        "user": "green",
        "assistant": "blue",
        "function": "magenta",
    }
    formatted_messages = []
    for message in messages:
        if message["role"] == "system":
            formatted_messages.append(f"system: {message['content']}\n")
        elif message["role"] == "user":
            formatted_messages.append(f"user: {message['content']}\n")
        elif message["role"] == "assistant" and message.get("function_call"):
            formatted_messages.append(f"assistant: {message['function_call']}\n")
        elif message["role"] == "assistant" and not message.get("function_call"):
            formatted_messages.append(f"assistant: {message['content']}\n")
        elif message["role"] == "function":
            formatted_messages.append(f"function ({message['name']}): {message['content']}\n")
    for formatted_message in formatted_messages:
        print(
            colored(
                formatted_message,
                role_to_color[messages[formatted_messages.index(formatted_message)]["role"]],
            )
        )
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值