Transformer模型结构解析与Python代码实现

本文首发于公众号【DeepDriving】,欢迎关注。

0. 前言

2017年,谷歌研究人员在《Attention Is All You Need》这篇论文中提出了Transformer模型,该模型最初是被用于机器翻译任务中。由于其良好的可并行性和强大的特征提取能力,Transformer模型在随后的几年中被用到自然语言处理、语音识别、计算机视觉等各个领域中,并表现出优异的性能。

本文基于论文的内容解读Transformer模型的各个组成部分,然后用Python实现一个完整的Transformer模型。

1. Transformer模型结构解析

1.1 模型总体架构

Transformer的总体架构如下图所示,模型包含一个编码器和解码器(分别对应下图中的左侧和右侧部分),编码器和解码器都是由一系列堆叠的注意力结构和全连接层组成。

编码器

编码器由 N = 6 N=6 N=6个相同的层组成,每个层又包含两个子层:第一个子层为多头自注意力机制,第二个子层为一个简单的全连接前馈网络。这两个子层都采用了残差连接结构,后面接一个LayerNorm层,也就是说,每个子层的输出为 L a y e r N o r m ( x + S u b L a y e r ( x ) ) LayerNorm(x+SubLayer(x)) LayerNorm(x+SubLayer(x))。因为使用了残差连接结构,模型中所有子层,包括输入的Embedding层,它们的输出维度 d m o d e l d_{model} dmodel都等于512

解码器

解码器也是由 N = 6 N=6 N=6个相同的层组成,除了使用了与编码器相同的子层外,解码器还在其中插入了第三个子层,这个子层对编码器的输出memory执行多头注意力机制。与编码器类似的,解码器的子层也采用残差连接结构,后面再接一个LayerNorm层。需要注意的是,解码器在多头自注意力子层中添加了一个掩码,这种机制可以确保对位置 i i i的预测只能依赖于小于位置 i i i的已知输出。

解码器的输出通过可学习的线性变换层和SoftMax函数转换为预测下一个Token的概率。

1.2 模型结构详解

1.2.1 注意力机制

注意力函数可以描述为将查询(query )和一组键(key)- 值(value)对映射到输出,其中querykey、和value都是向量。注意力函数的功能就是计算value的加权和,其中分配给每个value的权重由与querykey相关的特定函数计算得出。

缩放点积注意力

作者提出的注意力称为缩放点积注意力,它的输入是维度为 d k d_{k} dkquerykey,以及维度为 d v d_{v} dvvalue。对于输入的query,首先计算它与key的点积并除以缩放系数 d k \sqrt{d_{k}} dk ,然后用一个SoftMax函数来计算应用到value上的权重,这个权重再与value做点积运算得到最终结果。

在实际应用中,会把一组query向量打包在一起组成矩阵Q,相应的keyvalue也分别打包为矩阵KV,然后同时用注意力函数进行计算:

A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V)=softmax(\frac{QK^{T}}{\sqrt{d_{k}}})V Attention(Q,K,V)=softmax(dk QKT)V

为什么QK的点积结果要除以系数 d k \sqrt{d_{k}} dk ?因为作者发现如果 d k d_{k} dk的值比较大,那么QK点积的结果会产生很大的值,这样经过SoftMax函数后会产生非常小的梯度而不利于模型训练。为了消除这种影响,作者把点积结果除以一个系数 d k \sqrt{d_{k}} dk ,这也是为什么作者把这种注意力称为缩放注意力的原因。

多头注意力

把输入的querykeyvalue用不同的、可学习的线性映射操作分别映射h次,映射后的维度分别为 d k d_{k} dk d k d_{k} dk d v d_{v} dv,然后每个映射的版本再并行地进行注意力计算,产生 d v d_{v} dv维度的输出结果。把这h个输出的结果拼接到一起然后再做一次映射,使得最后输出结果的维度与原始输入相同。作者把这种多次映射再分别进行注意力计算的结构称为多头注意力,它比只使用一个维度为 d m o d e l d_{model} dmodelquerykeyvalue来计算注意力的效果要好很多。

与单头注意力结构相比,多头注意力使得模型具备关注来自不同表示子空间信息的能力,模型的学习能力更强大。多头注意力机制其实就是将输入序列进行多组自注意力处理的过程,可以用公式表示为:

M u l t i H e a d ( Q , K , V ) = C o n c a t ( h e a d 1 , … , h e a d h ) W O w h e r e    h e a d i = A t t e n t i o n ( Q W i Q , K W i K , V W

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

DeepDriving

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值