2018-Self-Attention with Relative Position Representations

1. Title

Self-Attention with Relative Position Representations
https://github.com/evelinehong/Transformer_Relative_Position_PyTorch

2. Summary

Transformer的核心结构Self-Attention机制由于其无法对输入token的相对位置或绝对位置信息进行建模,因此,目前主流的方案都是在输入token之外再额外加上一个Positional Encoding来引入位置信息。本文则是从Self-Attention机制内部出发,通过在计算过程中引入token之间的相对位置关系向量,打破了Self-Attention机制的Permutation-Invariant特性,从而更高效地完成了位置信息的编码,性能得到了提升。
阅读本文主要是在阅读Vision Transformer相关论文中看到了相关应用,Relative Positional Encoding在CV领域也有很多应用,对Vision Transformer性能的提升也是比较明显的。

3. Problem Statement

不同于RNN、CNN,Transformer结构没有显式对相对或者绝对位置进行建模的能力,为此,目前常见的做法是输入中额外添加包含位置信息的特征表示。
但是本文则是从另一个角度出发,Transformer之所以无法对相对或者绝对位置建模,是因为其核心操作Self-Attention是Permutation Invariant,这个性质的简单说明可以参见我另一篇博客:Conditional Positional Encodings for Vision Transformers
因此,倘若能够打破Self-Attention操作的Permutation Invariant特性,即可不再需要额外的位置信息的输入。

4. Method(s)

4.1 Relation-aware Self-Attention

将输入看做是一个带有标签的有向全连接图。
对于两个输入元素 x i x_i xi x j x_j xj之间的边通过两个向量来表示 a i j V , a i j K ∈ R d a a_{i j}^{V}, a_{i j}^{K} \in \mathbb{R}^{d_{a}} aijV,aijKRda,这些向量表示在多个head之间共享, d a = d z d_a=d_z da=dz。通过引入边的特征表示,原始的Self-Attention机制修改为以下计算方式:
z i = ∑ j = 1 n α i j ( x j W V + a i j V ) z_{i}=\sum_{j=1}^{n} \alpha_{i j}\left(x_{j} W^{V}+a_{i j}^{V}\right) zi=j=1nαij(xjWV+aijV) α i j = exp ⁡ e i j ∑ k = 1 n exp ⁡ e i k \alpha_{i j}=\frac{\exp e_{i j}}{\sum_{k=1}^{n} \exp e_{i k}} αij=k=1nexpeikexpeij e i j = x i W Q ( x j W K + a i j K ) T d z e_{i j}=\frac{x_{i} W^{Q}\left(x_{j} W^{K}+a_{i j}^{K}\right)^{T}}{\sqrt{d_{z}}} eij=dz xiWQ(xjWK+aijK)T
即对于各个Value和Key来说,都会引入一个相互的位置关系表示,从而打破了Self-Attention的Permutation-Invariant。

4.2 Relative Position Representation

考虑到计算量、内存消耗以及远距离的精确位置信息效用不是很足等因素,本文对最远的Relative Position Distance限制为 k k k
a i j K = w c l i p ( j − i , k ) K a i j V = w c l i p ( j − i , k ) V clip ⁡ ( x , k ) = max ⁡ ( − k , min ⁡ ( k , x ) ) \begin{aligned} a_{i j}^{K} &=w_{\mathrm{clip}(j-i, k)}^{K} \\ a_{i j}^{V} &=w_{\mathrm{clip}(j-i, k)}^{V} \\ \operatorname{clip}(x, k) &=\max (-k, \min (k, x)) \end{aligned} aijKaijVclip(x,k)=wclip(ji,k)K=wclip(ji,k)V=max(k,min(k,x))

在这种设定下,仅需要学习 w K = ( w − k K , … , w k K ) w^{K}=\left(w_{-k}^{K}, \ldots, w_{k}^{K}\right) wK=(wkK,,wkK) w V = ( w − k V , … , w k V ) w^{V}=\left(w_{-k}^{V}, \ldots, w_{k}^{V}\right) wV=(wkV,,wkV)

下面结合https://github.com/evelinehong/Transformer_Relative_Position_PyTorch这份代码,对这个部分进行更详细地阐述。

class RelativePosition(nn.Module):

    def __init__(self, num_units, max_relative_position):
        super().__init__()
        self.num_units = num_units
        self.max_relative_position = max_relative_position
        self.embeddings_table = nn.Parameter(torch.Tensor(max_relative_position * 2 + 1, num_units))
        nn.init.xavier_uniform_(self.embeddings_table)

    def forward(self, length_q, length_k):
        range_vec_q = torch.arange(length_q)
        range_vec_k = torch.arange(length_k)
        distance_mat = range_vec_k[None, :] - range_vec_q[:, None]
        distance_mat_clipped = torch.clamp(distance_mat, -self.max_relative_position, self.max_relative_position)
        final_mat = distance_mat_clipped + self.max_relative_position
        final_mat = torch.LongTensor(final_mat)
        embeddings = self.embeddings_table[final_mat]

        return embeddings

代码说明见下图:
在这里插入图片描述

4.3 Efficient Implementation

对于一个长度为 n n n和一个head数为 h h h的Multi-Head Self-Attention来说,通过在多个head之间共享Relative Position Representation,使得其空间复杂度由 O ( h n 2 d a ) O(hn^2d_a) O(hn2da)下降至 O ( n 2 d a ) O(n^2d_a) O(n2da),同时,不同Sequence之间也可以进行共享。
因此,对于一个batchsize为 b b b的序列来说,其空间复杂度由 O ( b h n d z ) O(bhnd_z) O(bhndz)上升为 O ( b h n d z + n 2 d a ) O(bhnd_z+n^2d_a) O(bhndz+n2da),其中 O ( n 2 d a ) O(n^2d_a) O(n2da) b b b个Sequence的Relative Position Representation所带来的额外空间消耗。

当没有Relative Position Representation时, e i j e_{ij} eij可以通过 b h bh bh个并行化的 n × d z n \times d_z n×dz d z × n d_z \times n dz×n进行矩阵乘法高效得到。这种高效计算的简单推导如下图:
在这里插入图片描述
当加入Relative Positional Representation之后,上述高效计算的前提就被打破了, e i j e_{ij} eij的计算不能分解为 q i q_i qi k j k_j kj两个独立的部分了,而是 q i q_i qi k i j k_{ij} kij两个不完全独立的部分,此时无法直接将其转化为高效的矩阵计算。

为了解决这个问题,作者将 k i j k_{ij} kij部分拆开,将其分为两个部分分开计算,每个部分可以独立采用一个并行化的高效计算矩阵运算来完成:
e i j = x i W Q ( x j W K ) T + x i W Q ( a i j K ) T d z e_{i j}=\frac{x_{i} W^{Q}\left(x_{j} W^{K}\right)^{T}+x_{i} W^{Q}\left(a_{i j}^{K}\right)^{T}}{\sqrt{d_{z}}} eij=dz xiWQ(xjWK)T+xiWQ(aijK)T

上式中,第一部分与未加入Relative Positional Representation时计算方式一样,第二部分则采用稍微不太一样的矩阵计算来完成:
记上式右侧部分为 e i j ′ e_{ij}' eij,记 x i W Q x_iW^Q xiWQ q i q_i qi,记 a i j K a_{ij}^K aijK k i j k_{ij} kij,忽略分母项,则右侧部分可表示为 e i j ′ = q i k i j T e_{ij}'=q_ik_{ij}^T eij=qikijT

  • 我们一共有 b h × n bh \times n bh×n q i q_i qi,每个 q i q_i qi的维度为 d z d_z dz
  • 同样我们一共有 n × n n \times n n×n k i j k_{ij} kij,每个 k i j k_{ij} kij的维度为 d z d_z dz

为了能够进行高效的矩阵计算,我们需要将 q i q_i qi k i j k_{ij} kij进行重新解释(reshape):

  • q i q_i qi也可以表示为我们一共有 n × b h n \times bh n×bh q i ′ q_i' qi,每个 q i ′ q_i' qi的维度为 d z d_z dz
  • k i j k_{ij} kij也可以表示为我们一共有 n × n n \times n n×n k i j ′ k_{ij}' kij,每个 k i j ′ k_{ij}' kij的维度为 d z d_z dz(含义没有发生变化)

此时我们便可以对 q i ′ q_i' qi k i j ′ k_{ij}' kij进行 n n n个并行化的两个大小为 b h × d z bh \times d_z bh×dz d z × n d_z \times n dz×n的矩阵计算来加速计算。最终再重新reshape回原始的大小即可完成 e i j e_{ij} eij两个部分的高效并行化计算。

具体可以参见以下代码:

class MultiHeadAttentionLayer(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout, device):
        super().__init__()

        assert hid_dim % n_heads == 0

        self.hid_dim = hid_dim
        self.n_heads = n_heads
        self.head_dim = hid_dim // n_heads
        self.max_relative_position = 2

        self.relative_position_k = RelativePosition(self.head_dim, self.max_relative_position)
        self.relative_position_v = RelativePosition(self.head_dim, self.max_relative_position)

        self.fc_q = nn.Linear(hid_dim, hid_dim)
        self.fc_k = nn.Linear(hid_dim, hid_dim)
        self.fc_v = nn.Linear(hid_dim, hid_dim)

        self.fc_o = nn.Linear(hid_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)

        self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(self, query, key, value, mask=None):
        # query = [batch size, query len, hid dim]
        # key = [batch size, key len, hid dim]
        # value = [batch size, value len, hid dim]
        batch_size = query.shape[0]
        len_k = key.shape[1]
        len_q = query.shape[1]
        len_v = value.shape[1]

        # get q k v
        query = self.fc_q(query)  # b n d
        key = self.fc_k(key)  # b n d
        value = self.fc_v(value)  # b n d

        r_q1 = query.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)  # b h n d/h
        r_k1 = key.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)  # b h n d/h
        attn1 = torch.matmul(r_q1, r_k1.permute(0, 1, 3, 2))  # first item of equal (5) b h n n

        r_q2 = query.permute(1, 0, 2).contiguous().view(len_q, batch_size * self.n_heads, self.head_dim)  # n b*h d/h
        r_k2 = self.relative_position_k(len_q, len_k)  # n n d/h
        attn2 = torch.matmul(r_q2, r_k2.transpose(1, 2)).transpose(0, 1)  # b*h n n
        attn2 = attn2.contiguous().view(batch_size, self.n_heads, len_q, len_k)  # second item of equal (5) b h n n
        attn = (attn1 + attn2) / self.scale 

        if mask is not None:
            attn = attn.masked_fill(mask == 0, -1e10)

        attn = self.dropout(torch.softmax(attn, dim=-1))

        # attn = [batch size, n heads, query len, key len]
        r_v1 = value.view(batch_size, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        weight1 = torch.matmul(attn, r_v1)
        r_v2 = self.relative_position_v(len_q, len_v)
        weight2 = attn.permute(2, 0, 1, 3).contiguous().view(len_q, batch_size * self.n_heads, len_k)
        weight2 = torch.matmul(weight2, r_v2)
        weight2 = weight2.transpose(0, 1).contiguous().view(batch_size, self.n_heads, len_q, self.head_dim)

        x = weight1 + weight2

        # x = [batch size, n heads, query len, head dim]

        x = x.permute(0, 2, 1, 3).contiguous()

        # x = [batch size, query len, n heads, head dim]

        x = x.view(batch_size, -1, self.hid_dim)

        # x = [batch size, query len, hid dim]

        x = self.fc_o(x)

        # x = [batch size, query len, hid dim]

        return x

5. Evaluation

本篇论文主要是用于NLP领域,其实验结果如下:
在这里插入图片描述在这里插入图片描述在这里插入图片描述

6. Conclusion

本文主要是从Self-Attention机制本身出发,在计算过程中引入了相对位置信息,从而打破了Self-Attention的Permutation-Invariant特性,提升了各个word之间关系构建能力。

  • 5
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值