attention背景
在seq2seq结构下,encoder-decoder模型中,模型首先将输入序列encode到固定长度的向量 h h h中,然后在decoder中将 h h h解码为输出序列。如下图所示:
在这种结构中,输入序列的信息被压缩到了向量 h h h中,模型根据 h h h和当前时刻的输出确定下一个时刻的输出。
随着序列长度的增加,当序列长度很长时,这种信息压缩方式会造成序列中较早时刻输入的信息损失。因此,为了解决这一问题,attnetion机制被引入到RNN中。
此外,我们在处理自然语言,希望在decoder的不同时刻,能将attention放在encoder的不同时刻的输入上。如翻译"今天天气真好"–“It’s a nice day today”,在输出"It’s a nice day"后,我们希望将attetion放在“今天”这个词汇上,完成翻译,而attention机制很好的实现了这一点。
attention模型
本文重点介绍两种attention机制,即Bahdanau Attention和Luong Attention
Bahdanau Attention
Bahdanau 提出一种基于encoder-decoder架构的attention机制。
论文地址:https://arxiv.org/pdf/1409.0473.pdf
模型架构如下:
原理如下:
首先定义输出条件概率如下:
p
(
y
i
∣
y
1
,
.
.
.
,
y
i
−
1
,
x
)
=
g
(
y
i
−
1
,
s
i
,
c
i
)
p(y_i | y_1,...,y_{i-1},{\bf{x}}) = g(y_{i-1},s_i,c_i)
p(yi∣y1,...,yi−1,x)=g(yi−1,si,ci)
其中,
s
i
s_{i}
si为decoder中
i
i
i时刻的隐状态,计算公式如下:
s
i
=
f
(
s
i
−
1
,
y
i
−
1
,
c
i
)
s_i=f(s_{i-1},y_{i-1},c_i)
si=f(si−1,yi−1,ci)
c
i
c_i
ci为
i
i
i时刻的上下文向量,计算公式如下:
1,首先计算decoder当中第
i
i
i个位置与encoder中第
j
j
j个位置的匹配度:
e
i
j
=
a
(
s
i
−
1
,
h
j
)
e_{ij}=a(s_{i-1},h_j)
eij=a(si−1,hj)
2,其次将
e
i
j
e_{ij}
eij进行softmax归一化,映射到概率空间,得到encoder每个位置的权重
α
i
j
=
e
x
p
(
e
i
j
)
∑
k
=
1
T
x
e
x
p
(
e
i
k
)
\alpha_{ij}=\frac{exp(e_{ij})}{\sum_{k=1}^{T_x}exp(e_{ik})}
αij=∑k=1Txexp(eik)exp(eij)
3,然后,对encoder中每个位置的隐向量加权求和,得到
c
i
c_i
ci
c
i
=
∑
j
=
1
T
x
α
i
j
h
j
c_i=\sum_{j=1}^{T_x}\alpha_{ij}h_j
ci=j=1∑Txαijhj
在得到上下文向量
c
i
c_i
ci之后,将其与
i
−
1
i-1
i−1时刻的输出
y
i
−
1
y_{i-1}
yi−1在embedding上(改变embedding维度的大小)拼接后,输入到decoder的RNN单元,得到
i
i
i时刻的隐状态
s
i
s_{i}
si,进一步得到输出
o
i
o_i
oi
o
i
=
s
o
f
t
m
a
x
(
W
v
o
c
a
b
s
i
)
o_i=softmax(W_{vocab}s_i)
oi=softmax(Wvocabsi)
Luong Attention
Bahdanau Attention根据 i − 1 i-1 i−1时刻的隐向量 s i − 1 s_{i-1} si−1计算 i i i时刻的上下文 c i c_i ci。Luong 提出了一种新的注意力计算方式,根据当前时刻的隐向量 s i s_{i} si计算 c i c_i ci
论文地址:https://arxiv.org/pdf/1508.04025.pdf
Luong Attention模型架构如下
Luong Attention没有改变经典encoder-decoder结构计算 s i s_{i} si的方式,而是在得到隐状态 s i s_{i} si后进一步计算attention,进而得到attention之后的输出。具体计算公式如下:
1,首先得到
i
i
i时刻decode的隐状态:
s
i
=
f
(
s
i
−
1
,
y
i
−
1
)
s_i=f(s_{i-1},y_{i-1})
si=f(si−1,yi−1)
2,根据
s
i
s_{i}
si计算与encoder中第
j
j
j个位置的匹配度:
e
i
j
=
a
(
s
i
,
h
j
)
e_{ij}=a(s_{i},h_j)
eij=a(si,hj)
3,将
e
i
j
e_{ij}
eij进行softmax归一化,映射到概率空间,得到encoder每个位置的权重
α
i
j
=
e
x
p
(
e
i
j
)
∑
k
=
1
T
x
e
x
p
(
e
i
k
)
\alpha_{ij}=\frac{exp(e_{ij})}{\sum_{k=1}^{T_x}exp(e_{ik})}
αij=∑k=1Txexp(eik)exp(eij)
4,对encoder中每个位置的隐向量加权求和,得到
c
i
c_i
ci
c
i
=
∑
j
=
1
T
x
α
i
j
h
j
c_i=\sum_{j=1}^{T_x}\alpha_{ij}h_j
ci=j=1∑Txαijhj
5,由
s
i
s_{i}
si和
c
i
c_i
ci得到加入了attention机制的隐状态
s
~
i
\tilde{s}_i
s~i,拼接(改变hidden维度大小)-变换-激活
s
~
i
=
t
a
n
h
(
W
c
[
s
i
,
c
i
]
)
\tilde{s}_i=tanh(Wc[s_{i},c_i])
s~i=tanh(Wc[si,ci])
6,根据
s
~
i
\tilde{s}_i
s~i计算输出
o
i
=
s
o
f
t
m
a
x
(
W
v
o
c
a
b
s
~
i
)
o_i=softmax(W_{vocab}\tilde{s}_i)
oi=softmax(Wvocabs~i)
总结以上两种attention机制,主要区别为:Bahdanau Attention根据 i − 1 i-1 i−1时刻的隐状态计算 i i i时刻的注意力;Luong Attention则根据 i i i时刻的隐状态计算 i i i时刻的注意力
Self Attention
前面两种attention机制都是基于encoder-decoder模型,计算不同时刻decoder输出与encoder之间的关系,即target与source之间的关系。
self-attention机制基于transformer模型,计算不同位置词向量之间的关系
。。。。待完善
attention实现
class BahdanauAttention(nn.Module):
def __init__(self, encode_hidden_size, decode_hidden_size):
super().__init__()
self.W = nn.Linear(decode_hidden_size+2*encode_hidden_size, decode_hidden_size)
self.V = nn.Linear(decode_hidden_size, 1)
def forward(self, query, values, mask):
# query:[1, batch, dec]
# values:[batch, seq_enc, 2*enc]
query = query.permute(1, 0, 2).expand(-1, values.size(1), -1)
mask = torch.unsqueeze(mask, -1)
score = self.V(torch.tanh(self.W(torch.cat((query, values), dim=-1)))) #[batch, seq_enc, 1]
masked_score = score.data.masked_fill(~mask, -1e6)
attention_weights = torch.softmax(masked_score, dim=1)
context_vector = torch.sum(attention_weights*values, dim=1, keepdim=True)
return context_vector
class Decoder(nn.Module):
def __init__(self, embedding_dim, decode_hidden_size, vocab_size):
super().__init__()
self.attention = BahdanauAttention(decode_hidden_size, decode_hidden_size)
self.gru = nn.GRU(embedding_dim+2*decode_hidden_size, decode_hidden_size, bidirectional=False, batch_first=True)
self.out = nn.Linear(decode_hidden_size, vocab_size)
# self.drop = nn.Dropout()
def forward(self, inputs, decode_hidden_state, encode_output, mask):
# inputs:[batch, 1, emb]
# decode_hidden_state:[1, batch, dec]
attention_vector = self.attention(decode_hidden_state, encode_output, mask) #[batch, 1, enc]
inputs = torch.cat((inputs, attention_vector), dim=-1) #[batch, 1, emb+enc]
decode_output, decode_hidden_state = self.gru(inputs, decode_hidden_state) #[batch, 1, dec] | [1, batch, dec]
decode_output = self.out(decode_output) #[batch, 1, vocab]
return decode_output, decode_hidden_state
class LuongAttention(nn.Module):
def __init__(self, encode_hidden_size, decode_hidden_size):
super().__init__()
self.W = nn.Linear(decode_hidden_size+2*encode_hidden_size, decode_hidden_size)
self.V = nn.Linear(decode_hidden_size, 1)
def forward(self, query, values, mask):
# query:[1, batch, dec]
# values:[batch, seq_enc, 2*enc]
query = query.permute(1, 0, 2).expand(-1, values.size(1), -1)
mask = torch.unsqueeze(mask, -1)
score = self.V(torch.tanh(self.W(torch.cat((query, values), dim=-1)))) #[batch, seq_enc, 1]
masked_score = score.data.masked_fill(~mask, -1e6)
attention_weights = torch.softmax(masked_score, dim=1)
context_vector = torch.sum(attention_weights*values, dim=1, keepdim=True)
return context_vector
class Decoder(nn.Module):
def __init__(self, embedding_dim, decode_hidden_size, vocab_size):
super().__init__()
self.attention = LuongAttention(decode_hidden_size, decode_hidden_size)
self.gru = nn.GRU(embedding_dim, decode_hidden_size, bidirectional=False, batch_first=True)
self.fc = nn.Linear(3*decode_hidden_size, decode_hidden_size)
self.out = nn.Linear(decode_hidden_size, vocab_size)
# self.drop = nn.Dropout()
def forward(self, inputs, decode_hidden_state, encode_output, mask):
# inputs:[batch, 1, emb]
# decode_hidden_state:[1, batch, dec]
decode_output, decode_hidden_state = self.gru(inputs, decode_hidden_state) #[batch, 1, dec] | [1, batch, dec]
attention_vector = self.attention(decode_hidden_state, encode_output, mask) # [batch, 1, enc]
decode_output = torch.tanh(self.fc(torch.cat((decode_output, attention_vector), dim=-1)))
decode_output = self.out(decode_output) #[batch, 1, vocab]
return decode_output, decode_hidden_state