$\infty$-former: Infinite Memory Transformer

Martins P., Marinho Z. and Martins A. ∞ \infty -former: Infinite Memory Transformer. arXiv preprint arXiv:2109.00301, 2021.

在transformer中引入一种长期记忆机制.

主要内容

假设 X ∈ R L × d X \in \mathbb{R}^{L \times d} XRL×d, 即每一行 x i x_i xi代表一个token对应的特征.
Attention需要进行如下的步骤:
Q = X W Q , K = X W K , V = X W V , Z = s o f t m a x ( Q K T d ) V . Q = XW^Q, K = X W^K, V = XW^V, \\ Z = \mathrm{softmax}(\frac{QK^T}{\sqrt{d}})V. Q=XWQ,K=XWK,V=XWV,Z=softmax(d QKT)V.
为了符号简易起见, 我们不考虑multi-head的情形, 下面的思想可以直接应用之.

我们知道, 可以通过径向基函数来逼近任意的连续函数:
∑ k b k ψ k ( t ) → f ( t ) . \sum_{k} b_k \psi_k (t) \rightarrow f(t). kbkψk(t)f(t).
现在, 我们令 t i = i L t_i = \frac{i}{L} ti=Li, 即对 L L L个tokens冠以时序, X X X的每一列都可以看成一个特殊的 f j ( t ) f_j(t) fj(t)的位于 t i , i = 0 , 1 , ⋯   , L − 1 t_i, i=0,1,\cdots, L-1 ti,i=0,1,,L1处的值.
给定 N N N个基函数 ψ k ( t ) , k = 0 , 1 , ⋯   , N − 1 \psi_k (t), k=0,1,\cdots, N-1 ψk(t),k=0,1,,N1, 我们要通过求解系数 b j = [ b j 0 , b j 1 , ⋯ b j , N − 1 ] T \bm{b}_j = [b_{j0}, b_{j1},\cdots b_{j,N-1}]^T bj=[bj0,bj1,bj,N1]T来逼近 f j f_j fj( X X X的第 j j j列).
Ψ ∈ R N × L , Ψ k i = ψ k ( t i ) \Psi \in \mathbb{R}^{N \times L}, \Psi_{ki}=\psi_{k}(t_i) ΨRN×L,Ψki=ψk(ti), B ∈ R d × N , B j k = b j k B \in \mathbb{R}^{d \times N}, B_{jk} = b_{jk} BRd×N,Bjk=bjk.
作者通过岭回归来求解系数 b b b:
B = arg ⁡ min ⁡ B ∥ B Ψ − X T ∥ F 2 + λ ∥ B ∥ F 2 , B = \arg \min_{B} \|B \Psi - X^T\|_F^2 + \lambda \|B\|_F^2, B=argBminBΨXTF2+λBF2,
其显示表达式为:
B = X T Ψ T ( Ψ Ψ T + λ I ) − 1 . B = X^T\Psi^T(\Psi\Psi^T + \lambda I)^{-1}. B=XTΨT(ΨΨT+λI)1.

X T ≈ B Ψ → x i ≈ B ψ ( t i ) . X^T \approx B\Psi \rightarrow x_i \approx B \psi (t_i). XTBΨxiBψ(ti).
现在我们用 X ~ : = Ψ T B T \tilde{X} := \Psi^T B^T X~:=ΨTBT来代替 X X X, 则
K = X ~ W K = Ψ T B T W K , V ~ = X ~ W V = Ψ T B T W V . K = \tilde{X} W^K = \Psi^TB^TW^K, \tilde{V} = \tilde{X}W^V = \Psi^TB^TW^V. K=X~WK=ΨTBTWK,V~=X~WV=ΨTBTWV.
注意, 我们并不对 Q Q Q进行替换, 因为这个只是用作长期的记录用, Q每次重新计算.
对于每个 q i q_i qi, 我们构建一个其关于 t t t的密度函数 p i ( t ) p_i(t) pi(t), 文中假设其满足高斯分布:
N ( t ; μ i ; σ i 2 ) . \mathcal{N}(t; \mu_i; \sigma_i^2). N(t;μi;σi2).
μ i , σ i 2 \mu_i, \sigma_i^2 μi,σi2分别通过如下估计:
μ i = s i g m o i d ( w μ T K q i ) = s i g m o i d ( w μ T B T W K q i ) , σ i 2 = s o f t p l u s ( w σ T K q i ) = s i g m o i d ( w σ T B T W K q i ) . \mu_i = \mathrm{sigmoid} (w_{\mu}^T K q_i) =\mathrm{sigmoid} (w_{\mu}^T B^TW^K q_i), \\ \sigma^2_i = \mathrm{softplus} (w_{\sigma}^T K q_i) =\mathrm{sigmoid} (w_{\sigma}^T B^TW^K q_i). \\ μi=sigmoid(wμTKqi)=sigmoid(wμTBTWKqi),σi2=softplus(wσTKqi)=sigmoid(wσTBTWKqi).
注意最后令 w T Ψ T = w T w^T\Psi^T = w^T wTΨT=wT既然 Ψ \Psi Ψ是事先确定的.
我们知道
s o f t m a x ( K q i d ) \mathrm{softmax}(\frac{Kq_i}{\sqrt{d}}) softmax(d Kqi)
实际上求解的是一个离散化的 p i ( t ) p_i(t) pi(t), 即 q i q_i qi k j k_j kj的相合程度, 而
s o f t m a x ( K q i d ) T V \mathrm{softmax}(\frac{Kq_i}{\sqrt{d}})^TV softmax(d Kqi)TV
实际上就是求解期望
E p i [ v ( t ) ] . \mathbb{E}_{p_i}[v(t)]. Epi[v(t)].
现在我们近似了一个连续的 p i ( t ) p_i(t) pi(t), 也可以通过这种方式得到最后的 z i z_i zi:
E p i [ v ( t ) ] = E p i [ ψ T ( t ) B T W V ] = E p i [ ψ T ( t ) ] B T W V . \mathbb{E}_{p_i}[v(t)] =\mathbb{E}_{p_i}[\psi^T(t)B^TW^V] =\mathbb{E}_{p_i}[\psi^T(t)]B^TW^V. Epi[v(t)]=Epi[ψT(t)BTWV]=Epi[ψT(t)]BTWV.
当我们取 ψ \psi ψ为高斯径向基函数的时候, 上述是由显示解的.

现在来剖析一下, 好在哪里?
原本的 K K K L × d L\times d L×d的, 现在由于我们只需要计算 B T W B^TW BTW, 故实际上只有 N × d N \times d N×d, 我们可以选取很大的 L L L但是选择较小的 N N N来避免较高的复杂度.

如何扩展?

难不成每一次都要重新计算 B B B? 倘若真的是这样就谈不上是长期记忆了.
作者采取了一种比较巧的方法, 实际上, 现在的 B ψ ( t ) B\psi(t) Bψ(t)可以看成是一个 d d d维的向量函数.
我们首先将其进行压缩至 [ 0 , τ ] , τ ∈ ( 0 , 1 ) [0, \tau], \tau \in (0, 1) [0,τ],τ(0,1):
B ψ ( t / τ ) , B\psi(t /\tau), Bψ(t/τ),
如此一来, 整个函数的能量集中在 [ 0 , τ ] [0, \tau] [0,τ]中, 我们可以用剩下的 ( τ , 1 ] (\tau, 1] (τ,1]来放置新的 X X X.
我们首先从 [ 0 , τ ] [0, \tau] [0,τ]中采样 M M M个点 t 0 , ⋯   , t M − 1 t_0, \cdots, t_{M-1} t0,,tM1, 并得到:
X p a s t = [ x 0 , ⋯   , x M − 1 ] T ∈ R M × d , x m = ψ T ( t m / τ ) B T . X_{past} = [x_0, \cdots, x_{M-1}]^T \in \mathbb{R}^{M \times d}, x_m=\psi^T(t_m/\tau)B^T. Xpast=[x0,,xM1]TRM×d,xm=ψT(tm/τ)BT.
加上新的 X n e w X_{new} Xnew, 我们有
X = [ X p a s t T , X n e w T ] T ∈ R ( M + L ) × d , X = [X_{past}^T, X_{new}^T]^T \in \mathbb{R}^{(M + L) \times d}, X=[XpastT,XnewT]TR(M+L)×d,
X X X按照上面的逻辑重新估计 B B B即可更新记忆.

关于如何采样这 M M M个点, 作者提了一种sticky memories的方法, 将其与密度函数联系在一起, 便不细讲了.

实验细节

在看这篇论文的时候, 困扰我的就是这个径向基函数是怎么选的?
举一个作者在Language Modeling中的例子便可:
选取150个高斯径向基函数 N ( t ; μ , σ 2 ) \mathcal{N}(t;\mu, \sigma^2) N(t;μ,σ2), 其中
μ \mu μ [ 0 , 1 ] [0, 1] [0,1]中均匀采样, σ ∈ { 0.01 , 0.05 } \sigma \in \{0.01, 0.05\} σ{0.01,0.05}.

还有用KL散度防止一般化就不讲了. 感觉本文有趣的点就是压缩这个地方, 还有对 Ψ \Psi Ψ的处理.

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值