文章目录
1. Title
Self-Attention with Relative Position Representations
https://github.com/evelinehong/Transformer_Relative_Position_PyTorch
2. Summary
Transformer的核心结构Self-Attention机制由于其无法对输入token的相对位置或绝对位置信息进行建模,因此,目前主流的方案都是在输入token之外再额外加上一个Positional Encoding来引入位置信息。本文则是从Self-Attention机制内部出发,通过在计算过程中引入token之间的相对位置关系向量,打破了Self-Attention机制的Permutation-Invariant特性,从而更高效地完成了位置信息的编码,性能得到了提升。
阅读本文主要是在阅读Vision Transformer相关论文中看到了相关应用,Relative Positional Encoding在CV领域也有很多应用,对Vision Transformer性能的提升也是比较明显的。
3. Problem Statement
不同于RNN、CNN,Transformer结构没有显式对相对或者绝对位置进行建模的能力,为此,目前常见的做法是输入中额外添加包含位置信息的特征表示。
但是本文则是从另一个角度出发,Transformer之所以无法对相对或者绝对位置建模,是因为其核心操作Self-Attention是Permutation Invariant,这个性质的简单说明可以参见我另一篇博客:Conditional Positional Encodings for Vision Transformers。
因此,倘若能够打破Self-Attention操作的Permutation Invariant特性,即可不再需要额外的位置信息的输入。
4. Method(s)
4.1 Relation-aware Self-Attention
将输入看做是一个带有标签的有向全连接图。
对于两个输入元素
x
i
x_i
xi和
x
j
x_j
xj之间的边通过两个向量来表示
a
i
j
V
,
a
i
j
K
∈
R
d
a
a_{i j}^{V}, a_{i j}^{K} \in \mathbb{R}^{d_{a}}
aijV,aijK∈Rda,这些向量表示在多个head之间共享,
d
a
=
d
z
d_a=d_z
da=dz。通过引入边的特征表示,原始的Self-Attention机制修改为以下计算方式:
z
i
=
∑
j
=
1
n
α
i
j
(
x
j
W
V
+
a
i
j
V
)
z_{i}=\sum_{j=1}^{n} \alpha_{i j}\left(x_{j} W^{V}+a_{i j}^{V}\right)
zi=j=1∑nαij(xjWV+aijV)
α
i
j
=
exp
e
i
j
∑
k
=
1
n
exp
e
i
k
\alpha_{i j}=\frac{\exp e_{i j}}{\sum_{k=1}^{n} \exp e_{i k}}
αij=∑k=1nexpeikexpeij
e
i
j
=
x
i
W
Q
(
x
j
W
K
+
a
i
j
K
)
T
d
z
e_{i j}=\frac{x_{i} W^{Q}\left(x_{j} W^{K}+a_{i j}^{K}\right)^{T}}{\sqrt{d_{z}}}
eij=dzxiWQ(xjWK+aijK)T
即对于各个Value和Key来说,都会引入一个相互的位置关系表示,从而打破了Self-Attention的Permutation-Invariant。
4.2 Relative Position Representation
考虑到计算量、内存消耗以及远距离的精确位置信息效用不是很足等因素,本文对最远的Relative Position Distance限制为
k
k
k。
a
i
j
K
=
w
c
l
i
p
(
j
−
i
,
k
)
K
a
i
j
V
=
w
c
l
i
p
(
j
−
i
,
k
)
V
clip
(
x
,
k
)
=
max
(
−
k
,
min
(
k
,
x
)
)
\begin{aligned} a_{i j}^{K} &=w_{\mathrm{clip}(j-i, k)}^{K} \\ a_{i j}^{V} &=w_{\mathrm{clip}(j-i, k)}^{V} \\ \operatorname{clip}(x, k) &=\max (-k, \min (k, x)) \end{aligned}
aijKaijVclip(x,k)=wclip(j−i,k)K=wclip(j−i,k)V=max(−k,min(k,x))
在这种设定下,仅需要学习 w K = ( w − k K , … , w k K ) w^{K}=\left(w_{-k}^{K}, \ldots, w_{k}^{K}\right) wK=(w−kK,…,wkK)和 w V = ( w − k V , … , w k V ) w^{V}=\left(w_{-k}^{V}, \ldots, w_{k}^{V}\right) wV=(w−kV,…,wkV)。
下面结合https://github.com/evelinehong/Transformer_Relative_Position_PyTorch这份代码,对这个部分进行更详细地阐述。
class RelativePosition(nn.Module):
def __init__(self, num_units, max_relative_position):
super().__init__()
self.num_units = num_units
self.max_relative_position = max_relative_position
self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
nn.init.xavier_uniform_(self.embeddings_table)
def forward(self, length_q, length_k):
range_vec_q = torch.arange(length_q)
range_vec_k = torch.arange(length_k)
distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
final_mat = distance_mat_clipped + self.max_relative_position
final_mat = torch.LongTensor(final_mat)
embeddings = self.embeddings_table[final_mat]
return embeddings
代码说明见下图:
4.3 Efficient Implementation
对于一个长度为
n
n
n和一个head数为
h
h
h的Multi-Head Self-Attention来说,通过在多个head之间共享Relative Position Representation,使得其空间复杂度由
O
(
h
n
2
d
a
)
O(hn^2d_a)
O(hn2da)下降至
O
(
n
2
d
a
)
O(n^2d_a)
O(n2da),同时,不同Sequence之间也可以进行共享。
因此,对于一个batchsize为
b
b
b的序列来说,其空间复杂度由
O
(
b
h
n
d
z
)
O(bhnd_z)
O(bhndz)上升为
O
(
b
h
n
d
z
+
n
2
d
a
)
O(bhnd_z+n^2d_a)
O(bhndz+n2da),其中
O
(
n
2
d
a
)
O(n^2d_a)
O(n2da)为
b
b
b个Sequence的Relative Position Representation所带来的额外空间消耗。
当没有Relative Position Representation时,
e
i
j
e_{ij}
eij可以通过
b
h
bh
bh个并行化的
n
×
d
z
n \times d_z
n×dz和
d
z
×
n
d_z \times n
dz×n进行矩阵乘法高效得到。这种高效计算的简单推导如下图:
当加入Relative Positional Representation之后,上述高效计算的前提就被打破了,
e
i
j
e_{ij}
eij的计算不能分解为
q
i
q_i
qi和
k
j
k_j
kj两个独立的部分了,而是
q
i
q_i
qi和
k
i
j
k_{ij}
kij两个不完全独立的部分,此时无法直接将其转化为高效的矩阵计算。
为了解决这个问题,作者将
k
i
j
k_{ij}
kij部分拆开,将其分为两个部分分开计算,每个部分可以独立采用一个并行化的高效计算矩阵运算来完成:
e
i
j
=
x
i
W
Q
(
x
j
W
K
)
T
+
x
i
W
Q
(
a
i
j
K
)
T
d
z
e_{i j}=\frac{x_{i} W^{Q}\left(x_{j} W^{K}\right)^{T}+x_{i} W^{Q}\left(a_{i j}^{K}\right)^{T}}{\sqrt{d_{z}}}
eij=dzxiWQ(xjWK)T+xiWQ(aijK)T
上式中,第一部分与未加入Relative Positional Representation时计算方式一样,第二部分则采用稍微不太一样的矩阵计算来完成:
记上式右侧部分为
e
i
j
′
e_{ij}'
eij′,记
x
i
W
Q
x_iW^Q
xiWQ为
q
i
q_i
qi,记
a
i
j
K
a_{ij}^K
aijK为
k
i
j
k_{ij}
kij,忽略分母项,则右侧部分可表示为
e
i
j
′
=
q
i
k
i
j
T
e_{ij}'=q_ik_{ij}^T
eij′=qikijT。
- 我们一共有 b h × n bh \times n bh×n个 q i q_i qi,每个 q i q_i qi的维度为 d z d_z dz
- 同样我们一共有 n × n n \times n n×n个 k i j k_{ij} kij,每个 k i j k_{ij} kij的维度为 d z d_z dz
为了能够进行高效的矩阵计算,我们需要将 q i q_i qi和 k i j k_{ij} kij进行重新解释(reshape):
- q i q_i qi也可以表示为我们一共有 n × b h n \times bh n×bh个 q i ′ q_i' qi′,每个 q i ′ q_i' qi′的维度为 d z d_z dz
- k i j k_{ij} kij也可以表示为我们一共有 n × n n \times n n×n个 k i j ′ k_{ij}' kij′,每个 k i j ′ k_{ij}' kij′的维度为 d z d_z dz(含义没有发生变化)
此时我们便可以对 q i ′ q_i' qi′和 k i j ′ k_{ij}' kij′进行 n n n个并行化的两个大小为 b h × d z bh \times d_z bh×dz和 d z × n d_z \times n dz×n的矩阵计算来加速计算。最终再重新reshape回原始的大小即可完成 e i j e_{ij} eij两个部分的高效并行化计算。
具体可以参见以下代码:
class MultiHeadAttentionLayer(nn.Module):
def __init__(self, hid_dim, n_heads, dropout, device):
super().__init__()
assert hid_dim % n_heads == 0
self.hid_dim = hid_dim
self.n_heads = n_heads
self.head_dim = hid_dim // n_heads
self.max_relative_position = 2
self.relative_position_k = RelativePosition(self.head_dim, self.max_relative_position)
self.relative_position_v = RelativePosition(self.head_dim, self.max_relative_position)
self.fc_q = nn.Linear(hid_dim, hid_dim)
self.fc_k = nn.Linear(hid_dim, hid_dim)
self.fc_v = nn.Linear(hid_dim, hid_dim)
self.fc_o = nn.Linear(hid_dim, hid_dim)
self.dropout = nn.Dropout(dropout)
self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)
def forward(self, query, key, value, mask=None):
# query = [batch size, query len, hid dim]
# key = [batch size, key len, hid dim]
# value = [batch size, value len, hid dim]
batch_size = query.shape[0]
len_k = key.shape[1]
len_q = query.shape[1]
len_v = value.shape[1]
# get q k v
query = self.fc_q(query) # b n d
key = self.fc_k(key) # b n d
value = self.fc_v(value) # b n d
r_q1 = query.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) # b h n d/h
r_k1 = key.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3) # b h n d/h
attn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2)) # first item of equal (5) b h n n
r_q2 = query.permute(1, 0, 2).contiguous().view(len_q, batch_size * self.n_heads, self.head_dim) # n b*h d/h
r_k2 = self.relative_position_k(len_q, len_k) # n n d/h
attn2 = torch.matmul(r_q2, r_k2.transpose(1, 2)).transpose(0, 1) # b*h n n
attn2 = attn2.contiguous().view(batch_size, self.n_heads, len_q, len_k) # second item of equal (5) b h n n
attn = (attn1 + attn2) / self.scale
if mask is not None:
attn = attn.masked_fill(mask == 0, -1e10)
attn = self.dropout(torch.softmax(attn, dim=-1))
# attn = [batch size, n heads, query len, key len]
r_v1 = value.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
weight1 = torch.matmul(attn, r_v1)
r_v2 = self.relative_position_v(len_q, len_v)
weight2 = attn.permute(2, 0, 1, 3).contiguous().view(len_q, batch_size * self.n_heads, len_k)
weight2 = torch.matmul(weight2, r_v2)
weight2 = weight2.transpose(0, 1).contiguous().view(batch_size, self.n_heads, len_q, self.head_dim)
x = weight1 + weight2
# x = [batch size, n heads, query len, head dim]
x = x.permute(0, 2, 1, 3).contiguous()
# x = [batch size, query len, n heads, head dim]
x = x.view(batch_size, -1, self.hid_dim)
# x = [batch size, query len, hid dim]
x = self.fc_o(x)
# x = [batch size, query len, hid dim]
return x
5. Evaluation
本篇论文主要是用于NLP领域,其实验结果如下:
6. Conclusion
本文主要是从Self-Attention机制本身出发,在计算过程中引入了相对位置信息,从而打破了Self-Attention的Permutation-Invariant特性,提升了各个word之间关系构建能力。