encoder decoder模型_(四十五)通俗易懂理解——Seq2Seq Attention模型

seq2seq 是一个Encoder–Decoder 结构的网络,它的输入一个序列,输出也是一个序列, Encoder 中将一个可变长度的信号序列变为固定长度的向量表达,Decoder 将这个固定长度的向量变成可变长度的目标的信号序列。

要了解这个首先要对RNN以及BiLSTM模型有一个清晰的理解,需要再回顾的查看以下文章:

梦里寻梦:(五)通俗易懂理解——BiLSTM​zhuanlan.zhihu.com

接下去要对Encoder-Decoder模型有一个清晰的认识,其在机器翻译上效果十分突出。

befc6a2b190402e70cdbc864d95a092f.png

如图中所展示,我们要翻译“知识就是力量。”这句话。Encoder是一个RNN,将要翻译的话转换成向量特征,输入到Decoder中。

简而言之,就是输入“知识就是力量”,然后经过神经网络后,输出一个向量,这个向量包含着丰富的语义信息,也即所谓的编码。然后再将该编码输入一套神经网络,最终输出“knowledge is power”。

7375c6cdf260f8d3f593b4a238b769cc.png

这是Encoder,一个RNN,C是RNN从输入x_1,x_2,x_3,x_4中提取的向量,或者说对x_1,x_2,x_3,x_4进行一个编码,得到c有多种方式,最简单的方法就是把Encoder的最后一个隐状态赋值给c,还可以对最后的隐状态做一个变换得到c,也可以对所有的隐状态做变换。

504abbd6640dd31469af0bddc71ed02d.png

59a2a800e3fc85e3b05f0c70722386de.png

获得C以后,就使用另一个RNN,Decoder,来对编码C进行解码,或者说根据向量C来学习获得正确的输出。上面两图中是两种输入方式,将C当做之前的初始状态h0输入到Decoder中和将C当做每一步的输入。

接下来讲述Seq2Seq Attention

大框架如下图所示

4b8cc90861ea2f2565338353f87151a2.png
想象一下翻译任务,input是一段英文,output是一段中文。

接下来按照步骤进行解说,注意查看图片下方的解释说明。

(1)

, Encoder方面接受的是每一个单词word embedding,和上一个时间点的hidden state。输出的是这个时间点的hidden state。

95d54a8235ec558360e578df104242ce.png
从左边Encoder开始,输入转换为word embedding, 进入LSTM。LSTM会在每一个时间点上输出hidden states。如图中的h1,h2,...,h8。

(2)

, Decoder方面接受的是目标句子里单词的word embedding,和上一个时间点的hidden state。

c497b0766e40b5944e49b3d4b4fef77d.png
接下来进入右侧Decoder,输入为(1) 句首符号<sos>,原始context vector(为0),以及从encoder最后一个hidden state: h8。LSTM的是输出是一个hidden state。(当然还有cell state,这里没用到,不提。)

我们将获得以下信息,也就是上图中hidden states层,包括encode和decode,但是decode此时只有一个信息:

d0c09c0eb40b253f3a3321f01204f9da.png

紧接着我们需要计算权重得分。

在luong中提到了三种score的计算方法。这里图解前两种:

425f4758daab466c566813c48fdff744.png

Attention score function: dot

a77f2c7e9364936d9fd9e55a0679d89e.png

输入是encoder的所有hidden states H: 大小为(hid dim, sequence length)。decoder在一个时间点上的hidden state, s: 大小为(hid dim, 1)。

第一步:旋转H为(sequence length, hid dim) 与s做点乘得到一个 大小为(sequence length, 1)的分数

(3)

, 通过decoder的hidden states加上encoder的hidden states来计算一个分数。

第二步:对分数做softmax得到一个合为1的权重

(4)

, 每一个encoder的hidden states对应的权重。

第三步:将H与第二步得到的权重做点乘得到一个大小为(hid dim, 1)的context vector

(5)

, context vector是一个对于encoder输出的hidden states的一个加权平均。

Attention score function: general

4ab180d9791f09a3c3b4a0f9e42e2b26.png

输入是encoder的所有hidden states H: 大小为(hid dim1, sequence length)。decoder在一个时间点上的hidden state, s: 大小为(hid dim2, 1)。此处两个hidden state的纬度并不一样。

第一步:旋转H为(sequence length, hid dim1) 与 Wa [大小为 hid dim1, hid dim 2)] 做点乘, 再和s做点乘得到一个 大小为(sequence length, 1)的分数

(3)

, 通过decoder的hidden states加上encoder的hidden states来计算一个分数。

第二步:对分数做softmax得到一个合为1的权重

(4)

, 每一个encoder的hidden states对应的权重。

第三步:将H与第二步得到的权重做点乘得到一个大小为(hid dim, 1)的context vector

(5)

, context vector是一个对于encoder输出的hidden states的一个加权平均。

我们得到了如下信息:

fe0e711ded10744a0ec2ebec80ccae0b.png
Decoder的hidden state与Encoder所有的hidden states作为输入,放入Attention模块开始计算一个context vector。之后会介绍attention的计算方法。

从上文步骤我们得到了如下信息,第一个context的信息:

1230b0a75fbb0c782e175868475b60ad.png

那么接下去就是重复上面一个过程,生成context的第二个第三个信息等到。

下一个时间点

73a8353a36193ea7e1c3276e81c3beea.png
来到时间点2,之前的context vector可以作为输入和目标的单词串起来作为lstm的输入。之后又回到一个hiddn state。以此循环。

(6)

, 将context vector 和 decoder的hidden states 串起来。

此时我们获得了如下信息:

17fcd086d5287d3c0801b0baf831cabe.png

(7)

,计算最后的输出概率。

37faca8596a05605180d80f04b97bdb9.png
另一方面,context vector和decoder的hidden state合起来通过一系列非线性转换以及softmax最后计算出概率。

完结

整个完整公式如下:

输入:

输出:

(1)

, Encoder方面接受的是每一个单词word embedding,和上一个时间点的hidden state。输出的是这个时间点的hidden state。

(2)

, Decoder方面接受的是目标句子里单词的word embedding,和上一个时间点的hidden state。

(3)

, context vector是一个对于encoder输出的hidden states的一个加权平均。

(4)

, 每一个encoder的hidden states对应的权重。

(5)

, 通过decoder的hidden states加上encoder的hidden states来计算一个分数,用于计算权重(4)

(6)

, 将context vector 和 decoder的hidden states 串起来。

(7)

,计算最后的输出概率。

详细图

170742ede5c274943d35625c45b57efe.png
左侧为Encoder+输入,右侧为Decoder+输出。中间为Attention。

很感谢原文作者的无私奉献,让我对这个模型有了进一步的了解,基于我自身看文章时候的困惑,对文章对顺序等方面做了一定的修改。也相信大家看完之后能够更加简单明了地理解该知识点。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值