Seq2Seq、SeqGAN、Transformer…你都掌握了吗?一文总结文本生成必备经典模型(一)...

2023

点击蓝字 关注我们

关注并星标

从此不迷路

计算机视觉研究院

89bc48d0794caffb6532bc66d8009298.gif

56ff5627aa632c9d098f302def027051.gif

计算机视觉研究院专栏

作者:Edison_G

本专栏将逐一盘点自然语言处理、计算机视觉等领域下的常见任务,并对在这些任务上取得过 SOTA 的经典模型逐一详解。前往 SOTA!模型资源站(sota.jiqizhixin.com)即可获取本文中包含的模型实现代码、预训练模型及 API 等资源。

公众号ID|ComputerVisionGzq

学习群|扫码在主页获取加入方式

转自《机器之心》

本期收录模型速览

aa373c72f3f1d236217a102a04ab37ad.png

c03c28aab17ca55997189f7fc8081c53.png

6b9d02d10a9298ca01e24efe0b31ce8c.png

b02c597496898730a69f7598f6343ac6.png


d492d966aa034498be006315afa1f4a7.png图1. 模型读取一个输入句子 "ABC "并生成 "WXYZ "作为输出句子。该模型在输出句末标记后停止预测。请注意,LSTM是反向读取输入句子的,因为这样做在数据中引入了许多短期的依赖关系,使优化问题更加容易

03b38604d9a005dc56b0857d7068d171.png

b09fd1b68d30c701b816d53e91711fd9.png

df6f454dafc7a00cee2ec86fa0df4bd9.png

c9681e9b9420462bee721af13d269389.png

项目SOTA!平台项目详情页
Seq2Seq(RNN)前往 SOTA!模型平台获取实现资源:https://sota.jiqizhixin.com/project/seq2seq

a593137c560b28a6169d06a980636832.png

d428435f40815ab525ec32915c52b348.png
图2.  RNN Encoder–Decoder 架构

d95333150d618cd02b878b250f8bb0a4.png

b4ab0b4fa0c9789725e071ebaed3fe2c.png

h_j的实际激活计算为:


cc0fc78ed703c9b2dde0a790eee66dae.pngb9e21fa4d543d2fd7c8e8fdc7d6cf53a.png

在这种表述中,当复位门接近0时,隐藏状态被强制忽略之前的隐藏状态,只用当前的输入进行复位。这有效地允许隐藏状态放弃任何在未来发现不相关的信息,因此,允许一个更紧凑的表述。

当前 SOTA!平台收录 Seq2Seq(LSTM) 共 2 个模型实现资源,支持的主流框架包含 PyTorch等。

2af9f7a8153cf5e7ccef3f6e94e34663.png

项目SOTA!平台项目详情页
Seq2Seq(LSTM)前往 SOTA!模型平台获取实现资源:https://sota.jiqizhixin.com/project/seq2seq-lstm


30ab05faa59e21fa051bfde525b36fca.png


c6e630a42a3dbacd68eae559fb41157c.png
图4. 模型在给定的源句(x_1, x_2, ..., x_T)中生成第t个目标词y_t

85d37eaa79a6c8713273f60b817759a6.png

c340f0c7a2c0802426624911dd4f8c38.png

项目SOTA!平台项目详情页
Seq2Seq+Attention前往 SOTA!模型平台获取实现资源:https://sota.jiqizhixin.com/project/seq2seq-attention


e8ad9be9a513f174ec417b7829176940.png

ce6156dcee932fa93f043f9be8f2e229.png

0262263cc53802dbc1d84d474ae142fe.png

fa152c42a690333c1f7ee88e63112d35.png

随机初始化G网络和D网络参数;通过MLE预训练G网络,目的是提高G网络的搜索效率;通过G网络生成部分负样预训练D网络;通过G网络生成sequence用D网络去评判,得到reward:


e58186cb8b40d3e4d35840a7adf980a1.png40b9fff765566376d664cc91693fd58b.png5c93767b588a6c0c37787a0761458e23.png

根据上式(4)计算得到每个action选择得到的奖励并求得累积奖励的期望,以此为loss function,并求导对网络进行梯度更新。其中,下式是标准的D网络误差函数,训练目标是最大化识别真实样本的概率,最小化误识别伪造样本的概率:


22e29feb850f5aab61b3d03c550694f6.png

最后,GAN网络的误差函数如上,循环以上过程直至收敛。


当前 SOTA!平台收录 SeqGAN 共 22 个模型实现资源,支持的主流框架包含 PyTorch、TensorFlow 等。

7264e667ba24ca6245eb3635c94cdfa2.png

项目SOTA!平台项目详情页
SeqGAN前往 SOTA!模型平台获取实现资源:https://sota.jiqizhixin.com/project/seqgan

Attention is all you need 

2017 年,Google 机器翻译团队发表的《Attention is All You Need》完全抛弃了RNN和CNN等网络结构,而仅仅采用Attention机制来完成机器翻译任务,并且取得了很好的效果,注意力机制也成为了研究热点。大多数竞争性神经序列转导模型都有一个编码器-解码器结构。编码器将输入的符号表示序列(x1, ..., xn)映射到连续表示的序列z=(z1, ..., zn)。给定z后,解码器每次生成一个元素的符号输出序列(y1, ..., ym)。在每个步骤中,该模型是自动回归的,在生成下一个符号时,将先前生成的符号作为额外的输入。Transformer遵循这一整体架构,在编码器和解码器中都使用了堆叠式自注意力和点式全连接层,分别在图6的左半部和右半部显示。


5b63697cf7e41d263864891578c74305.png
图6. Transformer架构

编码器。编码器是由N=6个相同的层堆叠而成。每层有两个子层。第一层是一个多头自注意力机制,第二层是一个简单的、按位置排列的全连接前馈网络。在两个子层的每一个周围采用了一个残差连接,然后进行层的归一化。也就是说,每个子层的输出是LayerNorm(x + Sublayer(x)),其中,Sublayer(x)是子层本身实现的函数。为了方便这些残差连接,模型中的所有子层以及嵌入层都会生成尺寸为dmodel=512的输出。

解码器。解码器也是由N=6个相同的层组成的堆栈。除了每个编码器层的两个子层之外,解码器还插入了第三个子层,它对编码器堆栈的输出进行多头注意力。与编码器类似,在每个子层周围采用残差连接,然后进行层归一化。进一步修改了解码器堆栈中的自注意力子层,以防止位置关注后续位置。这种masking,再加上输出嵌入偏移一个位置的事实,确保对位置i的预测只取决于小于i的位置的已知输出。

Attention。注意力函数可以描述为将一个查询和一组键值对映射到一个输出,其中,查询、键、值和输出都是向量。输出被计算为值的加权和,其中分配给每个值的权重是由查询与相应的键的兼容性函数计算的。在Transformer中使用的Attention是Scaled Dot-Product Attention, 是归一化的点乘Attention,假设输入的query q 、key维度为dk,value维度为dv , 那么就计算query和每个key的点乘操作,并除以dk ,然后应用Softmax函数计算权重。Scaled Dot-Product Attention的示意图如图7(左)。

9391c6b13751c41a001ebf6baa3c5d03.png

图7. (左)按比例的点乘法注意力。(右)多头注意力由几个平行运行的注意力层组成

如果只对Q、K、V做一次这样的权重操作是不够的,这里提出了Multi-Head Attention,如图7(右)。具体操作包括:

  1. 首先对Q、K、V做一次线性映射,将输入维度均为dmodel 的Q、K、V 矩阵映射到Q∈Rm×dk,K∈Rm×dk,V∈Rm×dv;

  2. 然后在采用Scaled Dot-Product Attention计算出结果;

  3. 多次进行上述两步操作,然后将得到的结果进行合并;

  4. 将合并的结果进行线性变换。

在完整的架构中,有三处Multi-head Attention模块,分别是:

  1. Encoder模块的Self-Attention,在Encoder中,每层的Self-Attention的输入Q=K=V , 都是上一层的输出。Encoder中的每个位置都能够获取到前一层的所有位置的输出。

  2. Decoder模块的Mask Self-Attention,在Decoder中,每个位置只能获取到之前位置的信息,因此需要做mask,其设置为−∞。

  3. Encoder-Decoder之间的Attention,其中Q 来自于之前的Decoder层输出,K、V 来自于encoder的输出,这样decoder的每个位置都能够获取到输入序列的所有位置信息。

在进行了Attention操作之后,encoder和decoder中的每一层都包含了一个全连接前向网络,对每个位置的向量分别进行相同的操作,包括两个线性变换和一个ReLU激活输出:


24a710e17c40ccf587a6e2bf3fcdd4c3.png

因为模型不包括recurrence/convolution,因此是无法捕捉到序列顺序信息的,例如将K、V按行进行打乱,那么Attention之后的结果是一样的。但是序列信息非常重要,代表着全局的结构,因此必须将序列的token相对或者绝对位置信息利用起来。这里每个token的position embedding 向量维度也是dmodel=512, 然后将原本的input embedding和position embedding加起来组成最终的embedding作为encoder/decoder的输入。其中,position embedding计算公式如下:


4f398e9ab393b3f47502d0c82ce81471.png

其中,pos表征位置,i表征维度。也就是说,位置编码的每个维度对应于一个正弦波。波长形成一个从2π到10000-2π的几何级数。选择这个函数是因为假设它可以让模型很容易地学会通过相对位置来参加,因为对于任何固定的偏移量k,PE_pos+k可以表示为PE_pos的线性函数。

当前 SOTA!平台收录 Transformer 共 9 个模型实现资源,支持的主流框架包含 TensorFlow、PyTorch等。

e65bd8c93709d1aca7de8aaa6334a25e.png

项目SOTA!平台项目详情页
Transformer前往 SOTA!模型平台获取实现资源:https://sota.jiqizhixin.com/project/transformer-2

前往 SOTA!模型资源站(sota.jiqizhixin.com)即可获取本文中包含的模型实现代码、预训练模型及API等资源。 

网页端访问:在浏览器地址栏输入新版站点地址 sota.jiqizhixin.com ,即可前往「SOTA!模型」平台,查看关注的模型是否有新资源收录。 

移动端访问:在微信移动端中搜索服务号名称「机器之心SOTA模型」或 ID 「sotaai」,关注 SOTA!模型服务号,即可通过服务号底部菜单栏使用平台功能,更有最新AI技术、开发资源及社区动态定期推送。

© THE END 

转载请联系本公众号获得授权

ea49b26cb6866d20642344eef03e60f9.gif

计算机视觉研究院学习群等你加入!

计算机视觉研究院主要涉及深度学习领域,主要致力于人脸检测、人脸识别,多目标检测、目标跟踪、图像分割等研究方向。研究院接下来会不断分享最新的论文算法新框架,我们这次改革不同点就是,我们要着重”研究“。之后我们会针对相应领域分享实践过程,让大家真正体会摆脱理论的真实场景,培养爱动手编程爱动脑思考的习惯!

88435e5072f861f3c58e1631824d34dd.jpeg

扫码关注

计算机视觉研究院

公众号ID|ComputerVisionGzq

学习群|扫码在主页获取加入方式

 往期推荐 

🔗

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值