[ICML 2019] Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks

Problem Addressed

  A central part of deep learning is learning good representations of data, but most network designs assume inputs of a fixed dimension. In some settings, however, such as multiple instance learning, the input is a set of instances and the label is associated with the entire set.

  There are two key requirements here: the model must be permutation invariant, and it must accept inputs of arbitrary size. Both are difficult for conventional neural networks. RNN-style models can handle variable-length inputs, but they are sensitive to the order of the input sequence.

Background

  Prior work includes set pooling methods: each element in a set is first independently fed into a feed-forward neural network that takes fixed-size inputs; the resulting feature-space embeddings are then aggregated using a pooling operation (mean, sum, max, or similar); the final output is obtained by further non-linear processing of the aggregated embedding.

  Such set-pooling models have been theoretically shown to approximate any set function (Zaheer, M., Kottur, S., Ravanbakhsh, S., Poczos, B., Salakhutdinov, R. R., and Smola, A. J. Deep Sets. In Advances in Neural Information Processing Systems (NeurIPS), 2017).

  However, because each element is fed into the network independently, this approach does not explicitly model how strongly the elements of the set interact with one another.

  Consider, as an example, amortized clustering: learning a parametric mapping from an input set of points to the centers of the clusters of points inside the set.

  The main difficulty is that the parametric mapping must assign each point to its corresponding cluster while modelling the explaining-away pattern, so that the resulting clusters do not attempt to explain overlapping subsets of the input set.

  Due to this innate difficulty, clustering is typically solved via iterative algorithms that refine randomly initialized clusters until convergence. Even though a neural network with a set pooling operation can approximate such an amortized mapping by learning to quantize space, a crucial shortcoming is that this quantization cannot depend on the contents of the set. This limits the quality of the solution and may also make optimization of such a model more difficult.

Proposed Method

  The authors propose a novel set-input deep neural network architecture called the Set Transformer, with three main contributions:

1. A self-attention mechanism processes every element of the set, encoding pairwise- or higher-order interactions between elements.
2. The time complexity of the Transformer's self-attention is reduced from $O(n^{2})$ to $O(nm)$, where $m$ is a fixed number of inducing points.
3. A self-attention mechanism is also used to aggregate features across elements, in the same spirit as matching networks, since clustering an element largely depends on the other elements of the set.

Pooling Architecture for Sets

  The target value of a set is the same under any ordering of its elements, so the order of the elements within the set should be ignored. A simple permutation-invariant model applies a pooling operation after embedding each element:

$$\operatorname{net}\left(\left\{x_{1}, \ldots, x_{n}\right\}\right)=\rho\left(\operatorname{pool}\left(\left\{\phi\left(x_{1}\right), \ldots, \phi\left(x_{n}\right)\right\}\right)\right)$$

  Here $\operatorname{pool}$ is, for example, a sum, and $\phi$ can be any continuous function. The formula can be read as two parts: an encoder $\phi$ that processes each of the $n$ items independently, and a decoder $\rho(\operatorname{pool}(\cdot))$ that aggregates the encoded features and produces the desired output.
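  As a concrete illustration, below is a minimal PyTorch sketch of this encoder/pool/decoder structure. The layer sizes and the choice of sum pooling are illustrative assumptions, not the exact configuration used in any particular paper:

```python
import torch
import torch.nn as nn

class SetPooling(nn.Module):
    """Minimal permutation-invariant model: net(X) = rho(pool({phi(x_i)}))."""
    def __init__(self, dim_in=2, dim_hidden=64, dim_out=1):
        super().__init__()
        # Encoder phi: applied to each element independently.
        self.phi = nn.Sequential(
            nn.Linear(dim_in, dim_hidden), nn.ReLU(),
            nn.Linear(dim_hidden, dim_hidden),
        )
        # Decoder rho: applied to the pooled embedding of the whole set.
        self.rho = nn.Sequential(
            nn.Linear(dim_hidden, dim_hidden), nn.ReLU(),
            nn.Linear(dim_hidden, dim_out),
        )

    def forward(self, x):            # x: (batch, n, dim_in)
        z = self.phi(x)              # (batch, n, dim_hidden)
        pooled = z.sum(dim=1)        # sum pooling over the set dimension
        return self.rho(pooled)      # (batch, dim_out)

# Shuffling the elements of each set does not change the output.
x = torch.randn(4, 10, 2)
model = SetPooling()
perm = torch.randperm(10)
assert torch.allclose(model(x), model(x[:, perm]), atol=1e-5)
```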

Attention

  Suppose we have $n$ query vectors (one for each of the $n$ elements of the set), each of dimension $d_{q}$, i.e. $Q \in \mathbb{R}^{n \times d_{q}}$. An attention function $\operatorname{Att}(Q, K, V)$ uses $n_{v}$ key-value pairs $K \in \mathbb{R}^{n_{v} \times d_{q}}$ and $V \in \mathbb{R}^{n_{v} \times d_{v}}$:

$$\operatorname{Att}(Q, K, V ; \omega)=\omega\left(Q K^{\top}\right) V$$

  $Q K^{\top} \in \mathbb{R}^{n \times n_{v}}$ measures how similar each pair of query and key vectors is; $\omega$ is then applied to these scores to form the weights of a weighted sum over the value vectors. This extends to multi-head attention:

$$\begin{array}{l} \operatorname{Multihead}(Q, K, V ; \lambda, \omega)=\operatorname{concat}\left(O_{1}, \cdots, O_{h}\right) W^{O} \\ \text{where } O_{j}=\operatorname{Att}\left(Q W_{j}^{Q}, K W_{j}^{K}, V W_{j}^{V} ; \omega_{j}\right) \end{array}$$
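  A minimal PyTorch sketch of this attention and its multi-head extension follows. It assumes $\omega$ is the usual row-wise softmax with $\sqrt{d}$ scaling, and it stacks the per-head projections $W_{j}^{Q}, W_{j}^{K}, W_{j}^{V}$ into single linear layers; these are implementation assumptions, not prescribed by the formulas above:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def att(Q, K, V):
    """Att(Q, K, V; omega) = omega(Q K^T) V, with omega a scaled row-wise softmax."""
    scores = Q @ K.transpose(-2, -1) / Q.shape[-1] ** 0.5  # pairwise query-key similarities
    return F.softmax(scores, dim=-1) @ V                   # weighted sum of the values

class Multihead(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        assert dim % num_heads == 0
        self.h = num_heads
        self.wq = nn.Linear(dim, dim)  # stacks W_j^Q for all heads
        self.wk = nn.Linear(dim, dim)  # stacks W_j^K
        self.wv = nn.Linear(dim, dim)  # stacks W_j^V
        self.wo = nn.Linear(dim, dim)  # W^O

    def forward(self, Q, K, V):        # Q: (b, n, dim); K, V: (b, n_v, dim)
        b, n, d = Q.shape
        def split(x):                  # (b, len, d) -> (b, h, len, d/h)
            return x.view(b, -1, self.h, d // self.h).transpose(1, 2)
        O = att(split(self.wq(Q)), split(self.wk(K)), split(self.wv(V)))
        O = O.transpose(1, 2).reshape(b, n, d)   # concatenate the heads
        return self.wo(O)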

Set Transformer

  This section describes the building blocks of the Set Transformer:

Permutation Equivariant (Induced) Set Attention Blocks

  The paper proposes the SAB and the ISAB, which use self-attention to concurrently encode the whole set. Both are built on the Multihead Attention Block (MAB) defined below:

$$\begin{array}{l} \operatorname{MAB}(X, Y)=\operatorname{LayerNorm}(H+\operatorname{rFF}(H)) \\ \text{where } H=\operatorname{LayerNorm}(X+\operatorname{Multihead}(X, Y, Y ; \omega)) \end{array}$$

  rFF is any row-wise feedforward layer, and LayerNorm is layer normalization. The MAB is an adaptation of the encoder block of the Transformer, without positional encoding and dropout.

  The SAB is then defined in terms of the MAB:

$$\operatorname{SAB}(X):=\operatorname{MAB}(X, X)$$
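  Below is a minimal sketch of the MAB and SAB, reusing the `Multihead` module from the attention sketch above; the two-layer `rFF` is an illustrative assumption:

```python
class MAB(nn.Module):
    """MAB(X, Y) = LayerNorm(H + rFF(H)), with H = LayerNorm(X + Multihead(X, Y, Y))."""
    def __init__(self, dim, num_heads):
        super().__init__()
        self.mha = Multihead(dim, num_heads)
        self.rff = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
        self.ln1 = nn.LayerNorm(dim)
        self.ln2 = nn.LayerNorm(dim)

    def forward(self, X, Y):
        H = self.ln1(X + self.mha(X, Y, Y))
        return self.ln2(H + self.rff(H))

class SAB(nn.Module):
    """SAB(X) = MAB(X, X): full self-attention over the set, O(n^2)."""
    def __init__(self, dim, num_heads):
        super().__init__()
        self.mab = MAB(dim, num_heads)

    def forward(self, X):
        return self.mab(X, X)
```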

  The time complexity of the SAB is $O(n^{2})$, which is expensive when the set is large ($n \gg 1$). The Induced Set Attention Block (ISAB) is therefore proposed to reduce this cost, as sketched below.
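  For reference, the paper defines $\operatorname{ISAB}_{m}(X) := \operatorname{MAB}(X, \operatorname{MAB}(I, X))$, where $I \in \mathbb{R}^{m \times d}$ are $m$ learnable inducing points; since $m$ is fixed, the cost drops from $O(n^{2})$ to $O(nm)$. A minimal sketch, reusing the `MAB` module above:

```python
class ISAB(nn.Module):
    """ISAB_m(X) = MAB(X, MAB(I, X)), where I are m learnable inducing points.
    Attention is computed against the m inducing points, so the cost is O(n*m)."""
    def __init__(self, dim, num_heads, num_inducing):
        super().__init__()
        self.I = nn.Parameter(torch.randn(1, num_inducing, dim))
        self.mab1 = MAB(dim, num_heads)   # H = MAB(I, X): (batch, m, dim)
        self.mab2 = MAB(dim, num_heads)   # ISAB(X) = MAB(X, H): (batch, n, dim)

    def forward(self, X):                 # X: (batch, n, dim)
        H = self.mab1(self.I.expand(X.size(0), -1, -1), X)
        return self.mab2(X, H)
```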
