seq2seq中的Global Attention机制的三种评分函数的理解—以pytorch为例

seq2seq中的Global Attention机制的三种评分函数的理解—以pytorch为例

1 seq2seq简介

seq2seq模型是机器翻译中常见模型,由编码器(encoder) + 解码器(decoder)组成,其中编解码器都是由一层或者多层RNN组成。seq2seq模型的目标是将可变长度序列作为输入,并使用固定大小的模型将可变长度序 列作为输出返回。具体的实现并不难,可参看论文,官网也有详细的教程。

在这里插入图片描述

2 Attention介绍

2.1 Local Attention 与 Global Attention

seq2seq 解码器的常见问题是,如果我们只依赖于上下文向量来编码整个输入序列的含义,那么我们很可能会丢失信息。尤其是在处理长输入序列时,这极大地限制了我们的解码器的能力。

  • Bahdanau et al 于2015年提出注意力(Attention)机制,即允许解码器关注输入序列的某些部分,而不是在每一步都使用完全固定的上下文,我们将它称为Local Attention

  • Luong et al 2015年 提出了Global Attention 机制,改善了Bahdanau et al. 的基础工作。关键的区别在于: a. Global Attention考虑所有编码器的隐藏状态;

    b. 通过Global Attention,我们仅使用当前步的解码器的隐藏状态来计算注意力权重 ;

这里重点介绍Global Attention的实现,以Pytorch为例。

在这里插入图片描述
在这里插入图片描述

2.2 Global Attention 介绍

观看上图的Attention Layer 模块,蓝色为编码器各个时间步,红色为解码器时间步。 具体操作为:

  • step1: 得到所有编码器时刻的隐藏状态的输出hs:维度为 [time_steps, hidden_size] ;

  • step2: 得到某个时刻的解码器的隐藏状态的输出ht:维度为[1, hidden_size] ;

  • step3: 通过某种评分函数score_f(), 即score_ti = score_f(ht, hs[i, :],) ,得到第ti 个时间步对应的score;

    即ht 与编码器每个时间步的输出的隐藏状态进行 score_f 操作,得到维度为 [time_steps, 1] 的score_t;

  • step4: weight_score = softmax( score_t) ,进行归一化操作,得到每个时间步的权重。维度为:[timesteps, 1] ;

  • step5: 将weight_score 作用于 hs, 即对编码器的输出hs 做一个权重平均:得到 context vector,维度为:[1, hidden_size];如下图中的 c1/c2/c3

在这里插入图片描述

  • step6: 将ht 和 context_vector 进行拼接,即new_ht = concat(ht, context_vector) ,维度为[1, 2* hidden_size];
  • step7: 得到最后输出概率为:pt = softmax(tanh(Wc * new_ht)); 这里面的Wc 可以通过 一个linear 或者FCN层来实现。 下面会有示例。

其实context 与ht 除了可以通过concat进行作用,也可以通过add 结合在一起。

在这里插入图片描述

2.3 评分函数score的的计算方法

上面介绍了Global Attention的方法步骤,其中step3中的评分函数的选取较为重要,可以通过以下三种方式来计算:

在这里插入图片描述

2.3.1 Dot 内积

内积方向较为简单:

import torch

#这里先不考虑batch
time_step = 5  #时间步数,encoder阶段有多少个时间步长
hidden_size = 4  #隐藏层大小

en_output = torch.randn((time_step, hidden_size))  #encoder阶段所有的隐藏状态[5,4]
de_hidden = torch.randn((1, hidden_size))  #decoder阶段的某一个time_step(ti)的隐状态[1,4]

#将de_hidden转置,与en_output相乘,得到score; 即为解码器ti时刻的隐藏状态对应的在编码器的所有输出隐藏状态上的权重
score = torch.matmul(en_output, de_hidden.T)  #[5,1] 

#将该权重 softmax(归一化) 
score = F.softmax(score,dim = 0) #[5,1]

#得到词向量context_vector
context_vector = torch.matmul(score.T, en_output) #[1,4] 
2.3.2 General

与Dot相比,General就多了一个 Wa, 这个Wa 主要通过Linear层来实现。

import torch

#这里先不考虑batch
time_step = 5  #时间步数,encoder阶段有多少个时间步长

#一般情况下,两者的hidden_size一致
en_hidden_size = 4  #编码阶段的hidden_size
de_hidden_size = 3  #解码阶段的hidden_size

en_output = torch.randn((time_step, en_hidden_size))  #encoder阶段所有的隐藏状态[5,4]
de_hidden = torch.randn((1, de_hidden_size ))  #decoder阶段的某一个time_step(ti)的隐状态[1,3]

atten = nn.Linear(en_hidden_size,de_hidden_size) #wa为 en_hidden_size --> de_hidden_size的之间的转换矩阵参数[en_hidden_size, de_hidden_size]=  [4,3]

w = atten(en_output)  #[5,3]

#得到 score; 即为解码器ti时刻的隐藏状态对应的在编码器的所有输出隐藏状态上的权重
score = torch.matmul(w, de_hidden.T)  #[5,1]

#将该权重 softmax(归一化) 
score = F.softmax(score,dim = 0) #[5,1]

#得到词向量context_vector
context_vector = torch.matmul(score.T, en_output) #[1,4]
2.3.3 Concat
import torch

#这里先不考虑batch
time_step = 5  #时间步数,encoder阶段有多少个时间步长
hidden_size = 4

en_output = torch.randn((time_step, hidden_size))  #encoder阶段所有的隐藏状态[5,4]
de_hidden = torch.randn((1, hidden_size))  #decoder阶段的某一个time_step(ti)的隐状态[1,4]

#
atten = torch.nn.Linear(hidden_size * 2, hidden_size) ##wa 为 hidden_size*2 -->  hidden_size之间的转换矩阵参数[hidden_size*2, hidden_size] = [8,4]

#需要将v加入Parameter中去,以便参与梯度更新和参数学习
v = torch.nn.Parameter(torch.FloatTensor(hidden_size)).view(hidden_size, -1) #[4,1]

#即将de_hidden拼接到每个en_output的每个time_step的列中
concat_en_de = torch.zeros(time_step, hidden_size*2) #[5,8] 
for i in range(time_step):
    concat_en_de[i,:hidden_size] = en_output[i,:]
    concat_en_de[i, hidden_size:] = de_hidden[0,:]
  
w = torch.tanh(atten(concat_en_de)) #[5,4] 

score = torch.matmul(w, v) #[5,1]

#将该权重 softmax(归一化) 
score = F.softmax(score,dim = 0) #[5,1]

context_vector = torch.matmul(score.T, en_output) #[1,4]  

总结: 网上用的Global Attention多用前两种score方法。一般经验General方法好于Dot方法。通过Attention注意力机制给Decoder RNN加入额外信息,可以显著提高seq2seq的性能。

3 seq2seq训练问题

知乎文章:白裳 — 完全解析RNN, Seq2Seq, Attention注意力机制 讲到seq2seq训练问题,之前一直没有注意这一点。原文如下:

值得一提的是,在seq2seq结构中将Yt作为下一个时刻的输入Xt+1 <= Yt 进网络,那么某一时刻输出Yt错误就会导致后面全错。在训练时由于网络尚未收敛,这种蝴蝶效应格外明显。

在这里插入图片描述

为了解决这个问题,Google提出了大名鼎鼎的Scheduled Sampling(即在训练中按照一定概率选择输入Yt-1 或者 t-1 时刻对应的真实值,即标签,如下图),既能加快训练速度,也能提高训练精度。

谢谢前人的分享,受益匪浅!

之前在Deecamp 夏令营 AI 降水预测总结 这篇文章中试验了很多类似seq2seq的方法,但是训练的时候其实并没有注意到这种训练过程中产生的 蝴蝶效应 问题。在以后的工作中需要多加注意。

参考链接

seq2seq论文:Sequence to Sequence Learning with Neural Networks

Local Attention 论文: NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE

Global Attention 论文:Effective Approaches to Attention-based Neural Machine Translation

Scheduled Sampling for Sequence Prediction with Recurrent Neural Networks

Pytorch官网教程

完全解析RNN, Seq2Seq, Attention注意力机制

真正的完全图解Seq2Seq Attention模型

  • 5
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值