R2D2:基于可微分树的预训练模型

https://arxiv.org/abs/2107.00967
在一次分享中看到这篇论文,感觉有意思细读了一下
主要是讲基于可微分树的递归transformer来实现具有强解释性的层次预训练语言模型

论文主要章节涉及了三个方面
  • 模型算法,讲解借助transformer实现对句子树结构的提取
  • 算法复杂度的优化,相比于之前提出的tree-LSTM是 n 3 n^3 n3复杂度降低到了线性复杂度
  • 在以上基础上进行大语料的预训练
相关背景知识
  1. 基于CKY算法的语法分析 介绍 博客
乔姆斯基范式(CNF,Chomsky Normal Form)

任何语法都可以转化成一个弱等价的CNF形式,CNF语法都是二分叉
在这里插入图片描述

CYK算法

CYK算法(也称为Cocke–Younger–Kasami算法)是一种用来对 上下文无关文法(CFG,Context Free Grammar)进行语法分析(parsing)的算法。该算法最早由John Cocke, Daniel Younger and Tadao Kasami分别独立提出,其中John Cocke还是1987年度的图灵奖得主。CYK算法是基于动态规划思想设计的一种自底向上语法分析算法。
看过最易懂的博文
代码实现
2. Gumbel-Softmax estimation
在自底向上的计算过程中,每个格子会有多种组合方式,在各种组合方式中,选择概率最大的组合,即argmax函数。但是argmax函数是不可导的,没有办法反向传播。
通过reparameterization对logits的输出拟合为onehot,同时保证梯度可以反向传播
对离散变量再参数化
4. 基于大语料的预训练语言模型的大概套路

模型结构设计
Differentiable Tree

数据结构图
该论文定义了一个类似于CKY形式的可微二叉树解析器
句子 S={s1,s2,s3,…sn}
如上图,每一个格子 T ( i , j ) = < e i , j , p i , j , p ~ i , j > \Tau(i,j)=<e_{i,j},p_{i,j},\tilde{p}_{i,j}> T(i,j)=<ei,j,pi,j,p~i,j>
e i , j e_{i,j} ei,j 是向量表征
p i , j p_{i,j} pi,j 是每一个步所有组合的概率
p ~ i , j \tilde{p}_{i,j} p~i,j是在[ s i s_i si, s j s_j sj]的子树的概率
树的末端节点是 T i , i \Tau_{i,i} Ti,i, e i , i e_{i,i} ei,i以当前输入 s i s_i si的向量初始化, p i , j p_{i,j} pi,j p ~ i , j \tilde{p}_{i,j} p~i,j初始化为1。
在这里插入图片描述

上述公式的k是指( s i s_i si, s j − 1 s_{j-1} sj1)之间的某一分割点(分割点不同,会对应出不同的组合)
第一个公式
f ( . ) f(.) f(.)是我们下一节Recursive Transformer定义的函数, p i , j k p_{i,j}^k pi,jk p ~ i , j k \tilde{p}_{i,j}^k p~i,jk分别指一步中组合的概率和其子树的概率
第二个公式
以K为分割点的子树的概率,是当前组合的概率和左右子树概率的乘积,这个和CKY算法是一致的
第三个公式
Straight Through Gumbel-Softmax ,通过一定方式实现类似argmax函数的可微
p i , j p_{i,j} pi,j p ~ i , j \tilde{p}_{i,j} p~i,j是基于所有分割点得到的 p i , j k p_{i,j}^k pi,jk p ~ i , j k \tilde{p}_{i,j}^k p~i,jk的组合
output: 计算得出权重
第四个公式
通过当前组合与权重系数的乘积计算出 e i , i e_{i,i} ei,i
第五个公式
通过概率向量与权重系数的乘积计算出新的概率向量

Recursive Transformer

Recursive Transformer-based encoder
这个图对应了上一节第一个公式。
中间shape的转换过程看图,不想转述了,最终输出的 p i , j p_{i,j} pi,j R 1 R^1 R1, c i , j k c_{i,j}^k ci,jk R d R^d Rd

Tree Recovery

通过Straight-Through Gumbel-Softmax在每一个cell选择最佳的分割点,Tree( T 1 , n \Tau_{1,n} T1,n), 从树的根节点自顶向下递归操作,选择的最佳分割点还原树的结构,类似于CKY算法最后的回溯过程

Complexity Optimization 复杂度优化

上述的 f ( . ) f(.) f(.)是整个模型的核心计算部分,我们可以通过树的剪枝归并算法来实现对 f ( . ) f(.) f(.)O( n 3 n^3 n3)
复杂度到线性复杂度

算法

在这里插入图片描述

寻找最佳的合并点

在这里插入图片描述

example

在这里插入图片描述
这张图展示了长度为6的句子的处理过程。
m表示设定的剪枝的阈值 T \Tau T 是一个二维数组,用来盛放自底向上计算的所有cell。
上上述图示的三个function:
TREEINDUCTION 是前向计算的过程,调用PRUNING进行剪枝,PRUNING调用FIND寻找最佳消并点。
计算m之下的cell,如上图(b)显示。
当cell的row大于等于m时,还原所有以第m行的节点为root节点的子树,调用PRUNING进行剪枝操作,
剪枝的第一步是找到局部最佳的merge点(上图c),剪掉部分的cell(上图d),返回一个新的 T \Tau T(上图e)
在FIND中,最佳分割点的候选集合需要满足两个条件
(1)在 T \Tau T的第二行
(2)在以第m行的节点为root节点的子树中有被使用到
然后在候选集合中选择(x.p *pl *pr)最高的cell T i , j \Tau_{i,j} Ti,j做为最佳merge点,对应的将 T i , ∗ \Tau_{i,*} Ti, T ∗ , j \Tau_{*,j} T,j剪掉,得到 T 3 \Tau^3 T3

实验

预训练目标:

  1. 学习词汇表征,在实际实验中是对于word piece的表征,选择WikiText-2数据集,长度在128以内的句子,mask词汇,输入左子树和右子树的embedding进行词汇预测
    因为剪枝操作,存在左子树或者右子树为空,以临近的最长子树来替代
    在这里插入图片描述

  2. 无监督成分句法分析
    在 WSJ and CTB 测试集计算F1
    在这里插入图片描述

基于word-piece的word、NP等的召回
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值