真正的完全图解Seq2Seq Attention模型

转载公众号:https://mp.weixin.qq.com/s/0k71fKKv2SRLv9M6BjDo4w
原创: 盛源车 机器学习算法与自然语言处理 1周前

https://zhuanlan.zhihu.com/p/40920384

作者:盛源车

知乎专栏:魔法抓的学习笔记

五分钟看懂seq2seq attention模型。

本文通过图片,详细地画出了seq2seq+attention模型的全部流程,帮助小伙伴们无痛理解机器翻译等任务的重要模型。

 

seq2seq 是一个Encoder–Decoder 结构的网络,它的输入一个序列,输出也是一个序列, Encoder 中将一个可变长度的信号序列变为固定长度的向量表达,Decoder 将这个固定长度的向量变成可变长度的目标的信号序列。--简书

 

好了别管了,接下来开始刷图吧。

大框架

 

 

想象一下翻译任务,input是一段英文,output是一段中文。

 

公式(直接跳过看图最佳)

输入: x = (x_1,...,x_{T_x})

输出: y = (y_1,...,y_{T_y})

(1) h_t = RNN_{enc}(x_t, h_{t-1}) , Encoder方面接受的是每一个单词word embedding,和上一个时间点的hidden state。输出的是这个时间点的hidden state。

(2) s_t = RNN_{dec}(\hat{y_{t-1}},s_{t-1}) , Decoder方面接受的是目标句子里单词的word embedding,和上一个时间点的hidden state。

(3) c_i = \sum_{j=1}^{T_x} \alpha_{ij}h_j , context vector是一个对于encoder输出的hidden states的一个加权平均。

(4) \alpha_{ij} = \frac{exp(e_{ij})}{\sum_{k=1}^{T_x}exp(e_{ik})} , 每一个encoder的hidden states对应的权重。

(5) e_{ij} = score(s_i, h_j) , 通过decoder的hidden states加上encoder的hidden states来计算一个分数,用于计算权重(4)

(6) \hat{s_t} = tanh(W_c[c_t;s_t]), 将context vector 和 decoder的hidden states 串起来。

(7) p(y_t|y_{<t},x) = softmax(W_s\hat{s_t}) ,计算最后的输出概率。

 

详细图

 

左侧为Encoder+输入,右侧为Decoder+输出。中间为Attention。

 

(1) h_t = RNN_{enc}(x_t, h_{t-1}) , Encoder方面接受的是每一个单词word embedding,和上一个时间点的hidden state。输出的是这个时间点的hidden state。

从左边Encoder开始,输入转换为word embedding, 进入LSTM。LSTM会在每一个时间点上输出hidden states。如图中的h1,h2,...,h8。

(2) s_t = RNN_{dec}(\hat{y_{t-1}},s_{t-1}) , Decoder方面接受的是目标句子里单词的word embedding,和上一个时间点的hidden state。

接下来进入右侧Decoder,输入为(1) 句首 &amp;amp;amp;amp;amp;amp;amp;lt;sos&amp;amp;amp;amp;amp;amp;amp;gt;符号,原始context vector(为0),以及从encoder最后一个hidden state: h8。LSTM的是输出是一个hidden state。(当然还有cell state,这里没用到,不提。)

(3) c_i = \sum_{j=1}^{T_x} \alpha_{ij}h_j , context vector是一个对于encoder输出的hidden states的一个加权平均。

(4) \alpha_{ij} = \frac{exp(e_{ij})}{\sum_{k=1}^{T_x}exp(e_{ik})} , 每一个encoder的hidden states对应的权重。

(5) e_{ij} = score(s_i, h_j) , 通过decoder的hidden states加上encoder的hidden states来计算一个分数,用于计算权重(4)

Decoder的hidden state与Encoder所有的hidden states作为输入,放入Attention模块开始计算一个context vector。之后会介绍attention的计算方法。

下一个时间点

来到时间点2,之前的context vector可以作为输入和目标的单词串起来作为lstm的输入。之后又回到一个hiddn state。以此循环。

 

(6) \hat{s_t} = tanh(W_c[c_t;s_t]), 将context vector 和 decoder的hidden states 串起来。

(7) p(y_t|y_{<t},x) = softmax(W_s\hat{s_t}) ,计算最后的输出概率。

另一方面,context vector和decoder的hidden state合起来通过一系列非线性转换以及softmax最后计算出概率。

 

在luong中提到了三种score的计算方法。这里图解前两种:

Attention score function: dot

 

输入是encoder的所有hidden states H: 大小为(hid dim, sequence length)。decoder在一个时间点上的hidden state, s: 大小为(hid dim, 1)。

第一步:旋转H为(sequence length, hid dim) 与s做点乘得到一个 大小为(sequence length, 1)的分数

第二步:对分数做softmax得到一个合为1的权重

第三步:将H与第二步得到的权重做点乘得到一个大小为(hid dim, 1)的context vector

 

Attention score function: general

 

输入是encoder的所有hidden states H: 大小为(hid dim1, sequence length)。decoder在一个时间点上的hidden state, s: 大小为(hid dim2, 1)。此处两个hidden state的纬度并不一样。

第一步:旋转H为(sequence length, hid dim1) 与 Wa [大小为 hid dim1, hid dim 2)] 做点乘, 再和s做点乘得到一个 大小为(sequence length, 1)的分数

第二步:对分数做softmax得到一个合为1的权重

第三步:将H与第二步得到的权重做点乘得到一个大小为(hid dim, 1)的context vector

 

完结

  • 47
    点赞
  • 252
    收藏
    觉得还不错? 一键收藏
  • 14
    评论
Transformer seq2seq是一种基于Transformer模型seq2seq模型。它使用编码器-解码器架构,输入一个序列,输出另一个序列。与传统的seq2seq模型相比,Transformer seq2seq使用Transformer blocks来代替循环网络。这种模型广泛应用于语音识别、机器翻译、语音翻译、语音合成和聊天机器人训练等NLP问题。它的泛用性很高,但有些特定任务可能需要使用经过定制的模型来获得更好的结果。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [Transformer与seq2seq](https://download.csdn.net/download/weixin_38705558/14034735)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 33.333333333333336%"] - *2* [Transformer神经网络学习笔记——Seq2Seq模型和Transformer](https://blog.csdn.net/qq_50199113/article/details/131562854)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 33.333333333333336%"] - *3* [Transformer、Attentionseq2seq model](https://blog.csdn.net/weixin_41712499/article/details/103199986)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_1"}}] [.reference_item style="max-width: 33.333333333333336%"] [ .reference_list ]
评论 14
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值