The Problem Addressed
A key part of deep learning is learning representations of data, but most networks are designed for inputs of a fixed dimension. In other settings, however, such as multiple instance learning (MIL), the input is a set of instances and the label is associated with the entire set.
Two properties are crucial here: the model must be permutation invariant, and the input can be of arbitrary size. Both are difficult for conventional neural networks to handle. RNNs can cope with variable-length inputs, but they are sensitive to the order of the input sequence.
Background
Prior work includes set pooling methods: "In this model, each element in a set is first independently fed into a feed-forward neural network that takes fixed-size inputs. Resulting feature-space embeddings are then aggregated using a pooling operation (mean, sum, max or similar). The final output is obtained by further non-linear processing of the aggregated embedding."
This approach has been proven able to approximate any set function (Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R. R., and Smola, A. J. Deep Sets. In Advances in Neural Information Processing Systems (NeurIPS), 2017).
However, it does not clearly model how the elements relate to one another, since each element is fed into the network independently.
Methods that learn a parametric mapping from an input set of points to the centers of clusters of points inside the set do not account for overlapping subsets of the input set: "The main difficulty is that the parametric mapping must assign each point to its corresponding cluster while modelling the explaining away pattern such that the resulting clusters do not attempt to explain overlapping subsets of the input set. Due to this innate difficulty, clustering is typically solved via iterative algorithms that refine randomly initialized clusters until convergence. Even though a neural network with a set pooling operation can approximate such an amortized mapping by learning to quantize space, a crucial shortcoming is that this quantization cannot depend on the contents of the set. This limits the quality of the solution and also may make optimization of such a model more difficult."
The Proposed Method
The authors propose a novel set-input deep neural network architecture called the Set Transformer, with three main contributions:
1. A self-attention mechanism processes every element in the set, encoding pairwise or higher-order interactions between the elements.
2. The Transformer's computational complexity is reduced from $O(n^{2})$ to $O(nm)$, where $m$ is a fixed number of inducing points.
3. Self-attention handles feature aggregation across samples, in the same spirit as Matching Networks; clustering, in particular, largely requires taking the other samples into account.
Pooling Architecture for Sets
The target value of a set does not depend on the order of its elements, so the ordering of the samples within the set is irrelevant. A simple permutation-invariant model applies a pooling operation after embedding each element:
$$\operatorname{net}\left(\left\{x_{1}, \ldots, x_{n}\right\}\right)=\rho\left(\operatorname{pool}\left(\left\{\phi\left(x_{1}\right), \ldots, \phi\left(x_{n}\right)\right\}\right)\right)$$
Here pool is a sum operation, and $\phi$ can be any continuous function. The formula can be viewed as two parts: an encoder $\phi$ that processes each of the $n$ elements independently, and a decoder $\rho(\operatorname{pool}(\cdot))$ that aggregates the encoded features to produce the desired output.
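The encoder/pool/decoder structure above can be sketched minimally as follows (the weights `W_phi`, `W_rho` and the dimensions are illustrative, not from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)
# Illustrative weights for a tiny encoder phi and decoder rho.
W_phi = rng.normal(size=(2, 8))  # maps each 2-d point to an 8-d embedding
W_rho = rng.normal(size=(8, 1))  # maps the pooled embedding to a scalar

def phi(x):
    """Encoder: applied to every element of the set independently."""
    return np.tanh(x @ W_phi)

def rho(z):
    """Decoder: non-linear processing of the aggregated embedding."""
    return np.tanh(z @ W_rho)

def net(X):
    """net({x_1,...,x_n}) = rho(pool({phi(x_1),...,phi(x_n)})), pool = sum."""
    return rho(phi(X).sum(axis=0))
```

Because pool is a sum, permuting the rows of `X` leaves the output unchanged, which is exactly the permutation invariance discussed above.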
Attention
Assume there are $n$ query vectors (corresponding to the $n$ elements of the set), each of dimension $d_{q}$: $Q \in \mathbb{R}^{n \times d_{q}}$. An attention function $\operatorname{Att}(Q, K, V)$ uses $n_{v}$ key-value pairs $K \in \mathbb{R}^{n_{v} \times d_{q}}$ and $V \in \mathbb{R}^{n_{v} \times d_{v}}$:
$$\operatorname{Att}(Q, K, V ; \omega)=\omega\left(Q K^{\top}\right) V$$
$Q K^{\top} \in \mathbb{R}^{n \times n_{v}}$ measures the similarity between each pair of query and key vectors; $\omega$ then turns these scores into weights for summing the values. This extends to multi-head attention:
$$\begin{aligned} \operatorname{Multihead}(Q, K, V ; \lambda, \omega) &= \operatorname{concat}\left(O_{1}, \cdots, O_{h}\right) W^{O} \\ \text{where } O_{j} &= \operatorname{Att}\left(Q W_{j}^{Q}, K W_{j}^{K}, V W_{j}^{V} ; \omega_{j}\right) \end{aligned}$$
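These two definitions can be sketched as follows, taking $\omega$ to be a row-wise softmax (a common choice); passing the per-head projection weights as plain lists is an arrangement chosen here for illustration:

```python
import numpy as np

def softmax(a):
    """Row-wise softmax: a common choice for the weighting function omega."""
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def att(Q, K, V):
    """Att(Q, K, V; omega) = omega(Q K^T) V."""
    return softmax(Q @ K.T) @ V

def multihead(Q, K, V, W_Q, W_K, W_V, W_O):
    """Concatenate h independently projected attention heads, then mix with W_O."""
    heads = [att(Q @ wq, K @ wk, V @ wv) for wq, wk, wv in zip(W_Q, W_K, W_V)]
    return np.concatenate(heads, axis=-1) @ W_O
```

Each row of $\omega(QK^{\top})$ sums to one, so every output row is a convex combination of the value vectors.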
Set Transformer
This section describes the structure of the Set Transformer:
Permutation Equivariant (Induced) Set Attention Blocks
The SAB and ISAB blocks are proposed; they use self-attention to concurrently encode the whole set.
$$\begin{aligned} \operatorname{MAB}(X, Y) &= \operatorname{LayerNorm}(H+\operatorname{rFF}(H)) \\ \text{where } H &= \operatorname{LayerNorm}(X+\operatorname{Multihead}(X, Y, Y ; \omega)) \end{aligned}$$
rFF is any row-wise feedforward layer and LayerNorm is layer normalization. The MAB is an adaptation of the Transformer encoder block without positional encoding and dropout.
The SAB is defined in terms of the MAB:
$$\operatorname{SAB}(X):=\operatorname{MAB}(X, X)$$
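A sketch of the MAB and SAB, using single-head attention in place of Multihead for brevity (the rFF weights and the dimension `d` are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8  # feature dimension (illustrative)

def softmax(a):
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x, eps=1e-5):
    """Normalize each row to zero mean and unit variance."""
    return (x - x.mean(axis=-1, keepdims=True)) / (x.std(axis=-1, keepdims=True) + eps)

W1, W2 = rng.normal(size=(d, d)), rng.normal(size=(d, d))

def rff(x):
    """Row-wise feedforward layer: applied to each row independently."""
    return np.maximum(x @ W1, 0) @ W2

def mab(X, Y):
    """MAB(X, Y) = LayerNorm(H + rFF(H)), H = LayerNorm(X + Att(X, Y, Y)).
    Single-head attention stands in for Multihead in this sketch."""
    H = layer_norm(X + softmax(X @ Y.T) @ Y)
    return layer_norm(H + rff(H))

def sab(X):
    """SAB(X) := MAB(X, X): O(n^2) self-attention over the set."""
    return mab(X, X)
```

Since every operation is either row-wise or symmetric attention over the set, permuting the rows of `X` permutes the rows of `sab(X)` in the same way: the block is permutation equivariant.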
The time complexity of SAB is $O(n^{2})$, which is expensive for large sets ($n \gg 1$). The Induced Set Attention Block (ISAB) is therefore proposed:
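The excerpt ends here; in the paper, the ISAB introduces $m$ learnable inducing points $I \in \mathbb{R}^{m \times d}$ and is defined as $\operatorname{ISAB}_m(X) = \operatorname{MAB}(X, \operatorname{MAB}(I, X))$, so each attention is between $n$ and $m$ vectors, giving the $O(nm)$ cost mentioned earlier. A self-contained sketch, again with single-head attention in place of Multihead and illustrative weights:

```python
import numpy as np

rng = np.random.default_rng(0)
d, m = 8, 4  # feature dimension and number of inducing points (illustrative)

def softmax(a):
    e = np.exp(a - a.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x, eps=1e-5):
    return (x - x.mean(axis=-1, keepdims=True)) / (x.std(axis=-1, keepdims=True) + eps)

W1, W2 = rng.normal(size=(d, d)), rng.normal(size=(d, d))

def rff(x):
    return np.maximum(x @ W1, 0) @ W2

def mab(X, Y):
    # Single-head attention stands in for Multihead in this sketch.
    H = layer_norm(X + softmax(X @ Y.T) @ Y)
    return layer_norm(H + rff(H))

I = rng.normal(size=(m, d))  # learnable inducing points (fixed here for the sketch)

def isab(X):
    """ISAB_m(X) = MAB(X, MAB(I, X)); both MABs cost O(n*m) rather than O(n^2)."""
    H = mab(I, X)     # the m inducing points attend to the n set elements
    return mab(X, H)  # the n elements attend back to the m summaries
```

The inner MAB compresses the set into $m$ summary vectors, and the outer MAB lets every element query those summaries, so the output still depends on the contents of the whole set while remaining permutation equivariant.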