《Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks》,本文的作者来自牛津大学。
Transformer借助Self-attention在顺序数据的处理上取得了巨大的成功,本文旨在处理集合数据(其出发点和Pointer Net是一样一样的),并且强调了对于集合的置换不变性。换句话说,模型的输出不依赖于输入的顺序,因此Set Transformer被用于处理集合数据。但是,对于大量数据的集合,自注意力的时间复杂度是
O
(
n
2
)
O(n^2)
O(n2),因此本文又采取了新的方法把时间复杂度降低到了
O
(
m
n
)
O(mn)
O(mn)。
Background
Pooling Architecture for Sets
联想到图池化的操作,置换不变性的模型的一个概述性的式子为:
[1]已经证明过当
p
o
o
l
pool
pool是sum池化
β
β
β,
ϕ
ϕ
ϕ是法人一连续函数的时候,所有的置换不变函数都可以表示成(1)的形式。(1)继续细分可以分为encoder
(
ϕ
)
(ϕ)
(ϕ) 和decoder
β
(
p
o
o
l
(
⋅
)
)
β(pool(·))
β(pool(⋅))两个部分,这就和Transformer的的架构一致。此外[1]还观察到尽管编码器是permutation-equivariant层的堆叠,该模型仍然保持置换不变性。反正甭管怎么说吧,此处为接下来的Set Transformer提供了理论支持。
Attention
在此也复习一下自注意力。
其中
K
Q
KQ
KQ的维度一样,都是
n
v
×
d
q
n_v×d_q
nv×dq的,因此
Q
K
T
QK^T
QKT的结果为
n
v
×
n
v
n_v×n_v
nv×nv的。然后
w
w
w是激活函数,
V
V
V的维度是
n
v
×
d
v
n_v×d_v
nv×dv,因此(3)输出的shape=[
n
v
×
d
v
n_v×d_v
nv×dv]。多头注意力则是多次注意力的拼接之后的线性变换:
其中
O
j
O_j
Oj就是一次注意力之后的结果,为了保证多层堆叠的时候不调节每一层之间的参数,Transformer中强制设置
d
q
M
=
d
q
M
=
d
/
h
d_q^M=d_q^M=d/h
dqM=dqM=d/h,也就是输出的维度和输入的维度一致。
Set Transformer
先用图给出下面几种不同的结构的示意:
首先,定义多头注意力Block(MAB):
H
H
H是输入的特征与其多头自注意力变换之后的特征的和再过一个Norm层。然后,再对
H
H
H做了rFF(row-wise feedforward layer)再和
H
H
H相加。借助MAB,Set Attention Block (SAB)被定义为:
SAB接受一个集合,并在集合中的元素之间执行自我注意,从而产生一个大小相同的集合。但是与Transformer不同的是,它缺少了position embedding以及dropout。由于SAB的输出包含关于输入集合X中元素之间成对交互的信息,因此可以和transformer堆叠多个SAB来编码更高阶的交互,这就是简化版本的transformer。之后,由于时间复杂度的关系,提出了一个简化的版本ISAB:
相比于SAB,注意力不在对两两输入节点(数据)计算,ISAB首先根据输入集将I转换成H,这类似于低秩投影或自编码器模型,其中输入(X)首先投影到低维对象(H),然后重构产生输出。
I
I
I的大小为
m
m
m,这是需要调节的超参数。
Pooling by Multihead Attention
上述的SAB的多层堆叠可以看做是encoder的过程。之后,对encoder的输出
Z
∈
R
n
∗
d
Z∈R^{n*d}
Z∈Rn∗d进行解码:
这里通过对一组可学习的k个种子向量S应用多头注意力来聚合特征,因此
P
M
A
k
PMA_k
PMAk的输出是
k
k
k个元素的集合。在大多数情况下,我们使用一个种子向量(k = 1),但对于需要k个相关输出的平摊聚类这样的问题,自然的做法是使用k个种子向量。而为了模拟这
k
k
k个种子之间的交互,又用了一个SAB(真不愧是Attention is all you need):
所以整体的模型的框架为:
至此模型的定义结束。实验我还就来一手不写。
Code!
来看代码。github地址为:https://github.com/juho-lee/set_transformer。
class SetTransformer(nn.Module):
def __init__(self, dim_input, num_outputs, dim_output,
num_inds=32, dim_hidden=128, num_heads=4, ln=False):
super(SetTransformer, self).__init__()
self.enc = nn.Sequential(
ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=ln),
ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=ln))
self.dec = nn.Sequential(
PMA(dim_hidden, num_heads, num_outputs, ln=ln),
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
SAB(dim_hidden, dim_hidden, num_heads, ln=ln),
nn.Linear(dim_hidden, dim_output))
def forward(self, X):
return self.dec(self.enc(X))
可以看到,encoder由两层ISAB组成,decoder由PMA以及两层SAB组成。接下来拆解看不同的Block:
MAB:
class MAB(nn.Module):
def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
super(MAB, self).__init__()
self.dim_V = dim_V
self.num_heads = num_heads # 多头注意力头数
self.fc_q = nn.Linear(dim_Q, dim_V) # 自注意力KQV参数
self.fc_k = nn.Linear(dim_K, dim_V)
self.fc_v = nn.Linear(dim_K, dim_V)
if ln:
self.ln0 = nn.LayerNorm(dim_V)
self.ln1 = nn.LayerNorm(dim_V)
self.fc_o = nn.Linear(dim_V, dim_V) # 对应公式(4)中的输出参数O
def forward(self, Q, K):
Q = self.fc_q(Q)
K, V = self.fc_k(K), self.fc_v(K)
dim_split = self.dim_V // self.num_heads
Q_ = torch.cat(Q.split(dim_split, 2), 0)
K_ = torch.cat(K.split(dim_split, 2), 0)
V_ = torch.cat(V.split(dim_split, 2), 0)
# 注意此处是采用了Transformer原文中的自注意力的定义:softmax(QK^T/√d)V的方式,与本论文中的公式稍有不同
A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
O = O + F.relu(self.fc_o(O)) # 对应公式(6)
O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
return O
SAB:。和论文中描述的一样,SAB只是输入都是X的MAB
class SAB(nn.Module):
def __init__(self, dim_in, dim_out, num_heads, ln=False):
super(SAB, self).__init__()
self.mab = MAB(dim_in, dim_in, dim_out, num_heads, ln=ln)
def forward(self, X):
return self.mab(X, X)
ISAB:,相比于SAB只是多了一个根据预定义参数 I I I求H的操作:
class ISAB(nn.Module):
def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
super(ISAB, self).__init__()
self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
nn.init.xavier_uniform_(self.I)
self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln) # 公式(10)
self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln) # 公式(9)
def forward(self, X):
H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
return self.mab1(X, H)
PMA:,根据一组seeds计算多个MAB,对应文中的公式(11)
class PMA(nn.Module):
def __init__(self, dim, num_heads, num_seeds, ln=False):
super(PMA, self).__init__()
self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
nn.init.xavier_uniform_(self.S)
self.mab = MAB(dim, dim, dim, num_heads, ln=ln)
def forward(self, X):
return self.mab(self.S.repeat(X.size(0), 1, 1), X)
参考文献
Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R. R., and Smola, A. J. Deep sets. In Advances in Neural Information Processing Systems (NeurIPS), 2017.