DB-GPT-HUB Text-to-SQL微调

DB-GPT-HUB Text-to-SQL微调

项目介绍

DB-GPT-Hub是一个利用LLMs实现Text-to-SQL解析的实验项目,主要包含数据集收集、数据预处理、模型选择与构建和微调权重等步骤,通过这一系列的处理可以在提高Text-to-SQL能力的同时降低模型训练成本,让更多的开发者参与到Text-to-SQL的准确度提升工作当中,最终实现基于数据库的自动问答能力,让用户可以通过自然语言描述完成复杂数据库的查询操作等工作。

本次微调使用的基座模型是Qwen-14B-Chat。

spider数据集,包含训练数据8659条,测试数据1034条。

安装Python3.10

本人在windows上和linux都安装了各个版本的python,参考我的这篇文章,也可以使用构建Docker镜像的方式。

训练

将数据解压到dbgpt_hub/data目录下,即dbgpt_hub/data/spider

生成数据

sh dbgpt_hub/scripts/gen_train_eval_data.sh

在单卡A6000训练,耗时11小时41分钟。

***** train metrics *****
epoch                    =         8.0
train_loss               =      0.0281
train_runtime            = 11:41:19.00
train_samples_per_second =       1.646
train_steps_per_second   =       0.103

训练参数设置

CUDA_VISIBLE_DEVICES=1 python dbgpt_hub/train/sft_train.py \
    --model_name_or_path /soft/Qwen-14B-Chat/ \
    --do_train \
    --dataset example_text2sql_train \
    --max_source_length 2048 \
    --max_target_length 512 \
    --finetuning_type lora \
    --lora_target c_attn \
    --template chatml \
    --lora_rank 64 \
    --lora_alpha 32 \
    --output_dir dbgpt_hub/output/adapter/qwen-14b-sql-lora \
    --overwrite_cache \
    --overwrite_output_dir \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 16 \
    --lr_scheduler_type cosine_with_restarts \
    --logging_steps 50 \
    --save_steps 2000 \
    --learning_rate 2e-4 \
    --num_train_epochs 8 \
    --plot_loss \
    --bf16  >> ${train_log}
    # --bf16#v100不支持bf16

测试

生成sql,如果不需要加载微调的checkpoint,将checkpoint_dir和finetuning_type去掉即可。

单卡A6000,每条sql生成大概1.45s。

sh ./dbgpt_hub/scripts/predict_sft.sh
# dbgpt_hub/scripts/predict_sft.sh
CUDA_VISIBLE_DEVICES=1 python dbgpt_hub/predict/predict.py \
    --model_name_or_path /soft/Qwen-14B-Chat/ \
    --template chatml \
    --finetuning_type lora \
    --checkpoint_dir dbgpt_hub/output/adapter/qwen-14b-sql-lora \
    --predicted_out_filename pred_sql.sql >> ${pred_log}

评估

python dbgpt_hub/eval/evaluation.py --plug_value --input dbgpt_hub/output/pred/pred_sql.sql

微调前后对比

简单sql(248条)中等sql(446条)复杂sql(174条)其他(166条)(1034条)
微调前准确率0.8630.7170.4830.3250.650
微调后准确率0.9350.7850.6030.4160.731

其他数据集

  • WikiSQL: 一个大型的语义解析数据集,由80,654个自然语句表述和24,241张表格的sql标注构成。WikiSQL中每一个问句的查询范围仅限于同一张表,不包含排序、分组、子查询等复杂操作。
  • CHASE: 一个跨领域多轮交互text2sql中文数据集,包含5459个多轮问题组成的列表,一共17940个<query, SQL>二元组,涉及280个不同领域的数据库。
  • BIRD-SQL:数据集是一个英文的大规模跨领域文本到SQL基准测试,特别关注大型数据库内容。该数据集包含12,751对文本到SQL数据对和95个数据库,总大小为33.4GB,跨越37个职业领域。BIRD-SQL数据集通过探索三个额外的挑战,即处理大规模和混乱的数据库值、外部知识推理和优化SQL执行效率,缩小了文本到SQL研究与实际应用之间的差距。
  • CoSQL:是一个用于构建跨域对话文本到sql系统的语料库。它是Spider和SParC任务的对话版本。CoSQL由30k+回合和10k+带注释的SQL查询组成,这些查询来自Wizard-of-Oz的3k个对话集合,查询了跨越138个领域的200个复杂数据库。每个对话都模拟了一个真实的DB查询场景,其中一个工作人员作为用户探索数据库,一个SQL专家使用SQL检索答案,澄清模棱两可的问题,或者以其他方式通知。
  • 按照NSQL的处理模板,对数据集做简单处理,共得到约20w条训练数据

问题解决

  1. poetry安装问题

    个人感觉poetry不太好用,更换了镜像源之后,解析下载缓慢,不知道是不是因为中断了poetry下载依赖,后续的依赖解析一直卡住,也没有超时提示。最后手动使用pip安装。

  2. 评估时内网nltk下载问题

    nltk需要下载相关语料,内网无法下载,外网也容易超时,到这里下载,将packages文件夹重命名为nltk_data,拷贝到报错说明的几个位置中的一个即可。

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 13
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值