5-2 自然语言处理NLP-seq2seq attention的提出-计算方式及pytorch实现

1、seq2seq attention的论文出处

主要是阅读经典论文《Neural Machine Translation by Jointly Learning to Align and Translate》https://arxiv.org/pdf/1409.0473.pdf

这篇论文中机器翻译采用seq2seq的encoder-decoder模型构建,将输入句子encoding成一个固定长度的向量,然后输入到decoder解码生成译文。

attention机制引入要解决的问题

encoder得到的固定向量直接送入decoder,不利于

(1)较长句子的信息传递 (2)提取decoder中目标单词需要关注原文中的语义信息


2、seq2seq decoder的构成

直接截取一下论文中的讲法吧,如下图:

式子(4)的意思就是解码器生成出来的每个token,由

1、y_{i-1}就是前一个token的特征向量;

2、s_{i} 就是当前的隐状态,这里的s_{i} 也是通过RNN得到的,可以把y_{i-1}c_{i}看成这一步RNN的输入特征向量;

3、c_{i} 就是上下文的特征,没有用attention之前就是encoder得到的固定向量。现在用了attention机制这篇文章中经典的attention机制就是(5)式得到的。

关于c_{i} 的计算中,简单来说就是encoder中各层隐状态h_{j} 与decoder的隐状态s_{i-1}做相似度(可以用一个MLP来实现等,下文再具体说明各种算法),然后求softmax得到权重系数\alpha _{ij},然后做加权和。当然这只是一种做法,实际怎么做可以有别的做法。


3、seq2seq attention的经典形式

Effective Approaches to Attention-based Neural Machine Translation 论文中,将注意力机制大致分为了全局(global)注意力和局部(local)注意力。

全局注意力指的是注意力分布在所有encoder得到隐状态中;局部注意力指注意力只存在于一些隐状态中。例如,图像的average pooling可视为全局注意力;max pooling可视为局部注意力

论文链接:https://arxiv.org/abs/1508.04025

3.1 global attention

直接上论文中的内容好了,这边global attention的意思和上一篇论文的一样,只是论文中用的符号不一样,应该一看就能知道,实际一个意思。

这里decoder的target hidden和encoder中的source hidden的score怎么算呢?文中介绍三种算法:

这里把score称为基于内容的函数,原因是这里的计算只考虑了隐状态间内容的相关性,并不包含时序信息。

(1)dot:transformer中的Q、K就是用dot计算

(2)general:俗称乘法注意力机制

(3)concat:俗称加法注意力机制

当然文中也提到了location based的global attention不过,似乎(不太确定可能本人孤陋寡闻基本没用过)是历史遗留的产物,大家可以自行阅读。

3.2 local attention

此处,出于这篇文章的完整性,提及一下文章还提到了local attention。提出原因是global attention的计算量大。以机器翻译任务为例,如果原文和译文语种差异较大采用global attention,但如果句法上对应较好,可以采用local attention尝试。


4、pytorch实现Seq2Seq中dot型attention的注意力

这里实现一个3.1中dot型的注意力,输入为encoder的各层隐状态encoder_states以及当前的decoder隐状态decoder_state_t,输出为注意力加权后的上下文状态c

class Seq2SeqAttentionMechanism(nn.Module):
    
    def __init__(self):
        super(Seq2SeqAttentionMechanism, self).__init__()
        
    def forward(self, decoder_state_t, encoder_states):
        
        bs, source_length, hidden_size = encoder_states.shape
        
        decoder_state_t = decoder_state_t.unsqueeze(1)
        decoder_state_t = torch.tile(decoder_state_t, dims = (1, source_length, 1))
        
        score = torch.sum(decoder_state_t * encoder_states, dim = -1) #[bs, source_length]
        
        attn_prob = F.softmax(score, dim = -1) #[bs, source_length]
        
        context = torch.sum(attn_prob.unsqueeze(-1) * encoder_states, 1) #[bs, hidden_size]
        
        return attn_prob, context

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch是一种深度学习框架,可以用于实现序列到序列(seq2seq)的机器翻译任务。在seq2seq模型中,编码器将源序列编码为一个固定长度的向量,解码器则将该向量解码为目标序列。为了提高翻译质量,可以使用注意力机制来在解码器中引入上下文信息。 在PyTorch实现seq2seq模型,可以使用nn.Module类来定义模型架构。首先,需要定义编码器和解码器的结构。编码器通常使用循环神经网络(RNN)或卷积神经网络(CNN)进行实现,而解码器则需要使用注意力机制。注意力机制可以使解码器关注输入序列中最相关的部分并根据其进行翻译。 实现注意力机制时,需要计算每个输入序列位置和当前解码器状态之间的相似度。这可以通过计算点积或使用神经网络来实现。然后,可以将相似度作为权重,对输入序列进行加权求和,以计算上下文向量。最后,将上下文向量与当前解码器状态组合在一起,以生成下一个目标序列符号的概率分布。 在训练过程中,可以使用交叉熵损失函数来计算模型输出与正确目标序列之间的差异,并使用反向传播算法更新模型参数。在推理过程中,可以使用贪婪搜索或束搜索来生成翻译结果。 总的来说,PyTorch提供了一种灵活且高效的方式实现seq2seq模型和注意力机制,可以用于各种自然语言处理任务,包括机器翻译、问答系统和对话生成等。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值