【论文阅读】Seq2SQL: 使用强化学习从自然语言生成结构化查询

摘要:关系数据库存储了大量的世界数据。然而,目前访问这些数据需要用户理解查询语言,比如SQL。我们提出了Seq2SQL,一个深度神经网络,用于将自然语言问题翻译成相应的SQL查询。我们的模型使用数据库循环内查询执行的奖励来学习生成查询的策略,其中包含不适合通过交叉熵损失进行优化的无序部分。此外,Seq2SQL利用SQL的结构来减少生成查询的空间,并显著简化生成问题。除了模型之外,我们还发布了WikiSQL,这是一个80654个手工注释的问题和SQL查询示例的数据集,分布在维基百科的24241个表中,比可比数据集大一个数量级。通过将基于策略的强化学习与查询执行环境应用于WikiSQL, Seq2SQL优于最先进的语义解析器,将执行准确率从35.9%提高到59.4%,将逻辑形式准确率从23.4%提高到48.3%。

        主要工作:1、介绍了Seq2SQL,这是一个深度神经网络,用于将自然语言问题翻译成相应的SQL查询。如下图所示,Seq2SQL由三个组件组成,它们利用SQL的结构来减少生成查询的输出空间。此外,它使用基于策略的强化学习(RL)来生成查询条件,这些条件由于其无序性而不适合使用交叉熵损失进行优化。我们使用混合目标训练Seq2SQL,结合交叉熵损失和数据库在循环中查询执行的RL奖励。这些特征允许Seq2SQL在查询生成时获得最先进的结果。

论文代码:

GitHub - tiwarikajal/Seq2SQL--Natural-Language-sentences-to-SQL-Queries

        模型将问题和表的列作为输入。生成相应的SQL查询,该查询在训练期间对数据库执行。执行的结果被用作Reward来返回训练强化学习算法。

        2、发布了WikiSQL数据集,是一个由80654个人工注释的自然语言问题、SQL查询和SQL表实例组成的语料库,这些实例是从维基百科的24241个HTML表中提取出来的。WikiSQL比以前提供逻辑形式和自然语言话语的语义分析数据集大一个数量级。以原始JSON格式和SQL数据库的形式发布了WikiSQL中使用的表。与WikiSQL同时发布了一个查询执行引擎,用于在循环中执行查询以学习策略的数据库。

WikiSQL数据集地址:

https://github.com/salesforce/WikiSQL

        在WikiSQL数据集上,Seq2SQL优于Dong先前最先进的语义解析模型,得到35.9%的执行精度,以及一个增强的指针网络基线,得到53.3%的执行精度。通过利用SQL查询的固有结构和使用实时查询执行的奖励信号应用策略梯度方法,Seq2SQL在WikiSQL上实现了最先进的性能,获得了59.4%的执行精度。

模型输入

        列名令牌由“Pick”、“#”、“CFL”、“Team”等组成;试题符号包括“How”、“many”、“CFL”、“teams”等;SQL令牌包括SELECT, WHERE, COUNT, MIN, MAX等。有了这个增强的输入序列,指针网络可以通过从输入中进行排他选择来生成SQL查询。

        假设有一个包含N个表列的列表和一个问题,如上图所示,并希望生成相应的SQL查询。

        上式表示第j列名称中的单词序列,其中,i表示第j列中的第i个单词,Tj表示第j列中的单词总数。设x^{q}x^{s}分别表示问题中的单词序列和SQL词汇表中的唯一单词集。

        将输入序列x定义为所有列名、问题和SQL词汇表的连接:

        其中,[a,b]表示序列a和b之间的连接,在相邻序列之间添加哨兵标记来划分边界。

模型结构:

1、增强网络

        编码器-解码器结构

        编码器部分

        两层双向长短期记忆网络(LSTM,Hochreiter & Schmidhuber, 1997)对输入x进行编码。编码器的输入是与输入序列中的单词相对应的嵌入。用h^{enc}表示编码器的输出,h_{t}^{enc}是与输入序列中第t个字对应的编码器的输出。encoder将x编码为h^{enc}和每个token对应的h_{t}^{enc}

        解码器部分

        解码器网络采用两层单向LSTM

        在每个解码器步骤s中,解码器LSTM将前一个解码步骤生成的query token y_{s-1}作为输入,并输出状态g_{s}。next,解码器对输入序列的每个位置t产生一个标量注意力分数:

        选择得分最高的输入标记作为生成的SQL查询的下一个标记:

        即通过注意力得分直接选取输入序列中最高注意力的token当做sql序列的当前输出

2、SEQ2SQL

        虽然增强网络可以解决SQL生成问题,但它没有利用SQL固有的结构。通常,SQL查询由三个组件组成。

        第一个组件是聚合操作符(例如count()),或者,不提供聚合操作符。

        第二个组件是SELECT列。

        第三个组件是查询的WHERE子句,包含过滤行所依据的条件。

        Seq2SQL主要包含三个部分,分别对应于聚合操作符、SELECT列和WHERE子句:

        首先,网络对查询的聚合操作进行分类,并添加一个null操作,该操作对应于无聚合。接下来,网络指向输入表中与SELECT列对应的一列。最后,网络使用指针网络生成查询的条件。前两个组件使用交叉熵损失进行监督,而第三代组件使用策略梯度进行训练,以解决查询条件的无序性质。利用SQL的结构允许Seq2SQL进一步减少查询的输出空间,从而获得比Seq2Seq和增强的指针网络更高的性能。

聚合操作:

        首先,对于输入序列中的每个token计算注意力得分:

        将分数向量归一化,生成输入编码的分布:

        κagg是编码器输入h^{enc}和归一化分数的加权总和

        设αagg表示聚合操作(如COUNT, MIN, MAX)和非聚合操作NULL的分数。通过对输入应用多层感知器来计算αagg:

        应用softmax函数得到了可能聚合操作集合上的分布βagg = softmax (αagg)。使用交叉熵损失进行聚合操作。

选择操作:

        SELECT列预测是一个匹配问题,可以使用指针来解决:给定列表示和问题表示的列表,我们选择与问题最匹配的列。为了生成列的表示,首先用LSTM对每个列名进行编码。特定列的表示形式为:

        这里,hj,t表示第j列token的编码器状态。

        为了构造问题的表示,我们使用与κagg相同的体系结构 计算另一个输入表示,但使用未绑定的权重。最后,在列表示上应用多层感知器,以输入表示为条件,计算每个列j的a分数:

        使用softmax函数对分数进行归一化,以产生可能的SELECT列上的分布βsel = softmax(αsel )。我们使用交叉熵损失Lsel来训练SELECT网络。

条件子句:

        可以使用指针解码器来训练WHERE子句。然而,使用交叉熵损失来优化网络有一个限制:查询的WHERE条件可以交换,查询产生相同的结果。假设我们有一个问题“哪些男性大于18岁”,查询SELECT name FROM insurance WHERE age > 18 and gender =“male”和SELECT name FROM insurance WHERE gender =“male”and age > 18。尽管没有精确的字符串匹配,但两个查询都获得了正确的执行结果。如果前者被提供为基础真理,那么使用交叉熵损失来监督生成将错误地惩罚后者。为了解决这个问题,我们应用强化学习来学习一种策略,直接优化执行结果的预期正确性(公式7)。从输出分布中采样以获得下一个Token,而不是在查询生成的每一步都强制学习。在生成过程结束时,我们对数据库执行生成的SQL查询以获得奖励。设y = [y1, y2,…], yT]表示WHERE子句中生成的token序列。设q (y)表示模型生成的查询,qg表示问题对应的基础真值查询。定义reward R (q (y), qg)为:

        loss计算公式如下:

        其中,py(yt)表示在时间步长t期间选择token yt的概率。使用单个蒙特卡罗样本y近似期望梯度。

混合目标函数

        使用梯度下降来训练模型以最小化目标函数。因此,总梯度是预测SELECT列的交叉熵损失、预测聚合操作的交叉熵损失和策略学习的梯度损失的加权和。

对比实验和消融实验:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值