transformer模型_Transformer模型细节理解及Tensorflow实现

Transformer模型基于Encoder-Decoder架构,使用Attention而非CNN/RNN。本文深入探讨Transformer的Encoder和Decoder,包括Multi-Head Attention、Feed Forward层,以及Padding和Sequence mask。同时,提供了Tensorflow实现的核心单元代码。
摘要由CSDN通过智能技术生成

Transformer模型使用经典的Encoder-Decoder架构,在特征提取方面抛弃了传统的CNN、RNN,而完全基于Attention机制,在Attention机制上也引入了Self-Attention和Context-Attention,下面结合Transformer架构图和Tensorflow实现了解一下Transformer

3a8e1842b72b7c696f431447230b5f26.png


一、Transformer架构简说

Transformer是Encoder-Decoder架构,因此先整体分为Encoder、Decoder两部分

Encoder由6个相同的层组成,每层包含如下两部分:

第一部分:Multi-Head Attention多头注意力层

第二部分:Feed Forward全连接层

以上两部分都包含Residual Connection残差连接、Add&Norm数据归一化

Decoder也有6个相同的层组成,每层包含如下三部分:

第一部分:Masked Multi-Head Attention 经Sequence Mask处理的多头注意力层

第二部分:Multi-Head Attention

第三部分:Feed Forward全连接层

以上两部分都包含Residual Connection残差连接、Add&Norm数据归一化


二、Transformer模型细节理解

1、Encoder部分的输入是Input Embedding + Positional Embedding,Decoder部分的输入是Output Embedding + Positional Embedding,在MachineTranslation任务中,Input Embedding对应了输入的待翻译文本,Output Embedding对应了翻译后的文本

2、Encoder部分的Multi-Head Attention是self attention,对应输入的Q,k,V均相同是Input Embedding + Positional Embwedding,Decoder部分的Masked Multi-Head

Attention是self attention,对应的Q,K,V相同都是Output Embedding + Positional   

Embedding,Decoder部分的Multi-Head Attention是Context attention,其中K,V相同来自于Encoder的输出memory,Q来自于Docoder词层的OutputEmbedding + Positional Embedding

3、Attention区别

self-attention和context-attention划分的区别是Attention衡量的是一个序列对自身的Attention权重还是一个序列对另一个序列的Attention权重,self-attention即计算自身的Attention权重,而Context-attention计算的是Encoder序列对Decoder序列的Attention权重;

ScaledDot-Product Attention和Multi-Head Attention划分是根据Attention权重计算方式,除了这些还有多种Attention权重计算方式,如下图所示

9c0d4a0520426510d0e30746555825f6.png

这里简单说一下Transformer中使用的Scaled Dot-Product Attention和Multi-Head

Attention

ab2c3283839e9d09ab807a7e0ed5fb3d.png

ScaledDot-Product Attention:通过Q,K矩阵计算矩阵V的权重系数

492758b1978609290c5d17967c0c6aed.png

Multi-HeadAttention:多头注意力是将Q,K,V通过一个线性映射成h个Q,K,V,然后每个都计算Scaled Dot-Product Attention,最后再合起来,Multi-HeadAttention目的还是对特征进行更全面的抽取

4、Residual Connection残差连接原理及作用,先通过下图认识一下残差连接

ce57aa024be30558b73ca48cb60e52d2.png

如上图网络某层对输入x作用后输出是F(x),那么增加残差连接即是在原来F(x)上加上x,输出变成了F(x)+x,即+x操作即是残差连接,残差连接的作用是通过+x,在网络反向传播的时候会多出一个常数项1,防止梯度消失

5、masked区别

Transformer中有两种mask,一种是padding mask,一种是sequence mask

(1) Padding mask每次批次输入序列的长度是不一样的,需要对序列进行对齐操作,具体做法以seq_length为标准,对大于seq_length的序列进行截断,对小于seq_length序列进行填充,但填充部分我们又不希望其被注意,因此填充部分为0,可以在填充位置加上一个负无穷大的数,这样经过softmax后便趋向于0

(2) Sequence masksequencemask是为了使得decodeer不能看到未来的解码信息,因为在transformer中,输出序列是训练的时候是全部一下子全部传入网络中的,不似RNN那种递归的形式,因此对于一个序列来说,在t时刻我们解码输出只应该依赖于t时刻之前的输出,而不应该看到t时刻之后的输出,如果看到了那就不需要解码了,因此我们需要将传入的解码数据进行sequence mask操作。具体操作是产生一个矩阵,矩阵的上三角全为1,对角线

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值