这篇报告是一篇全面性的综述,详细介绍了目前用来解释 Transformer 语言模型内部运作机制的技术。本文聚焦在生成式的 decoder-only 架构。作者在结论部分概述了这些模型的已知内部机制,揭示了当前流行方法和这个领域的积极研究方向之间的联系。
论文出处: https://arxiv.org/pdf/2405.00208
数学符号说明
在本文中,我们采用以下数学符号:
-
n n n:序列长度
-
V V V:词汇表
-
t = ⟨ t 1 , t 2 , … , t n ⟩ t = \langle t_1, t_2, \ldots, t_n \rangle t=⟨t1,t2,…,tn⟩:输入的 token 序列
-
x = ⟨ x 1 , x 2 , … , x n ⟩ x = \langle x_1, x_2, \ldots, x_n \rangle x=⟨x1,x2,…,xn⟩:输入的 token 嵌入序列
-
d d d:模型维度
-
d h d_h dh:注意力头维度
-
d F F N d_{FFN} dFFN:前馈网络维度
-
H H H:注意力头数量
-
L L L:层数
-
x l , i ∈ R d x_{l,i} \in \mathbb{R}^d xl,i∈Rd:在位置 i i i,第 l l l 层的残差串流状态
-
x mid , l , i ∈ R d x_{\text{mid},l,i} \in \mathbb{R}^d xmid,l,i∈Rd:在位置 i i i,第 l l l 层,经过注意力区块后的残差串流状态
-
f c ( x ) ∈ R d f_c(x) \in \mathbb{R}^d fc(x)∈Rd:在最后一个位置,组件 c c c 的输出表示
-
f l ( x ) = x l , n ∈ R d f_l(x) = x_{l,n} \in \mathbb{R}^d fl(x)=xl,n∈Rd:在最后一个位置,第 l l l 层的残差串流状态
-
A l , h ∈ R n × n A_{l,h} \in \mathbb{R}^{n \times n} Al,h∈Rn×n:在第 l l l 层第 h h h 个注意力头的注意力矩阵
-
W l , h Q , W l , h K , W l , h V ∈ R d × d h W_{l,h}^Q, W_{l,h}^K, W_{l,h}^V \in \mathbb{R}^{d \times d_h} Wl,hQ,Wl,hK,Wl,hV∈Rd×dh:在第 l l l 层第 h h h 个注意力头的 queries、keys 和 values 权重矩阵
-
W l , h O ∈ R d h × d W_{l,h}^O \in \mathbb{R}^{d_h \times d} Wl,hO∈Rdh×d:在第 l l l 层第 h h h 个注意力头的输出权重矩阵
-
W l in ∈ R d × d F F N , W l out ∈ R d F F N × d W_{l}^{\text{in}} \in \mathbb{R}^{d \times d_{FFN}}, W_{l}^{\text{out}} \in \mathbb{R}^{d_{FFN} \times d} Wlin∈Rd×dFFN,Wlout∈RdFFN×d:在第 l l l 层前馈网络的输入和输出权重矩阵
-
W E ∈ R d × ∣ V ∣ W_E \in \mathbb{R}^{d \times |V|} WE∈Rd×∣V∣ 和 W U ∈ R ∣ V ∣ × d W_U \in \mathbb{R}^{|V| \times d} WU∈R∣V∣×d:嵌入(Embedding)和去嵌入(Unembedding)矩阵
一个仅译码器的模型 f f f 有 L L L 层,对一个序列的嵌入 x = ⟨ x 1 , x 2 , … , x n ⟩ x = \langle x_1, x_2, \ldots, x_n \rangle x=⟨x1,x2,…,xn⟩ 进行操作,这些嵌入表示 tokens t = ⟨ t 1 , t 2 , … , t n ⟩ t = \langle t_1, t_2, \ldots, t_n \rangle t=⟨t1,t2,…,tn⟩。每个嵌入 x ∈ R d x \in \mathbb{R}^d x∈Rd 是嵌入矩阵 W E ∈ R ∣ V ∣ × d W_E \in \mathbb{R}^{|V| \times d} WE∈R∣V∣×d 的一个列向量,其中 V V V 是模型词汇表。中间层的表示,例如在位置 i i i 和层 l l l,表示为 x l , i x_{l,i} xl,i。 X ∈ R n × d X \in \mathbb{R}^{n \times d} X∈Rn×d 代表将序列 x x x 表示为一个矩阵,嵌入堆栈为列。同样地,对于中间表示, X l , ≤ i X_{l, \leq i} Xl,≤i 是层 l l l 的表示矩阵,最多到位置 i i i。
遵循最近关于Transformers可解释性的文献,本文采用残差串流的观点来呈现架构。在这个观点下,每个输入嵌入都会透过注意力区块和前馈网络区块的向量相加来更新,产生残差串流状态(或中间表示)。最后一层的残差串流状态会透过去嵌入矩阵 W U ∈ R d × ∣ V ∣ W_U \in \mathbb{R}^{d \times |V|} WU∈Rd×∣V∣ 投影到词汇空间,并透过softmax函数正规化以获得词汇上的机率分布,从中取样出新的token。
在Transformer层中,LayerNorm是一种常见的操作,用于稳定深度神经网络的训练过程。给定一个表示 z z z,LayerNorm的计算为 ( z − μ ( z ) ) / σ ( z ) ⊙ γ + β (z-\mu(z))/\sigma(z) \odot \gamma + \beta (z−μ(z))/σ(z)⊙γ+β,其中 μ \mu μ 和 σ \sigma σ 分别计算平均值和标准偏差, γ ∈ R d \gamma \in \mathbb{R}^d γ∈Rd 和 β ∈ R d \beta \in \mathbb{R}^d β∈Rd 是学习到的逐元素转换和偏差。LayerNorm可以从几何角度解释,将减去平均值的操作视为将输入表示投影到由法向量 [ 1 , 1 , … , 1 ] ∈ R d [1, 1, \ldots, 1] \in \mathbb{R}^d [1,1,…,1]∈Rd 定义的超平面上,然后将结果表示映射到 d \sqrt{d} d 范数的超球面。
注意力区块由多个注意力头组成。在解碼步骤 i i i,每个注意力头从先前位置(≤ i i i)的残差串流读取,决定要关注哪些位置,从那些位置收集信息,最后将其写入当前的残差串流。
注意力头计算如下:
Attn l , h ( X l − 1 , ≤ i ) = ∑ j ≤ i a l , h , i , j x l − 1 , j W l , h V W l , h O = ∑ j ≤ i a l , h , i , j x l − 1 , j W l , h O V \text{Attn}_{l,h}(X_{l-1, \leq i}) = \sum_{j \leq i} a_{l,h,i,j} x_{l-1,j} W_{l,h}^V W_{l,h}^O = \sum_{j \leq i} a_{l,h,i,j} x_{l-1,j} W_{l,h}^{OV} Attnl,h(Xl−1,≤i)=j≤i∑al,h,i,jxl−1,jWl,hVWl,hO=j≤i∑al,h,i,jxl−1,jWl,hOV
可学习的权重矩阵 W l , h V ∈ R d × d h W_{l,h}^V \in \mathbb{R}^{d \times d_h} Wl,hV∈Rd×dh 和 W l , h O ∈ R d h × d W_{l,h}^O \in \mathbb{R}^{d_h \times d} Wl,hO∈Rdh×d 合并为 OV 矩阵 W l , h V W l , h O = W l , h O V ∈ R d × d W_{l,h}^V W_{l,h}^O = W_{l,h}^{OV} \in \mathbb{R}^{d \times d} Wl,hVWl,hO=Wl,hOV∈Rd×d,也称为 OV (output-value) 电路。对于每个查询(query) i i i,给定当前的键(key)( ≤ i \leq i ≤i),注意力权重计算为:
a l , h , i = softmax ( x l − 1 , i W l , h Q ( X l − 1 , ≤ i W l , h K ) T d k ) a_{l,h,i} = \text{softmax}\left( \frac{x_{l-1,i} W_{l,h}^Q (X_{l-1, \leq i} W_{l,h}^K)^T}{\sqrt{d_k}} \right) al,h,i=softmax(dkxl−1,iWl,hQ(Xl−1,≤iWl,hK)T)
其中 W l , h Q ∈ R d × d h W_{l,h}^Q \in \mathbb{R}^{d \times d_h} Wl,hQ∈Rd×dh 和 W h Q ( W h K ) T = W h Q K ∈ R d × d W_{h}^Q (W_{h}^K)^T = W_{h}^{QK} \in \mathbb{R}^{d \times d} WhQ(WhK)T=WhQK∈Rd×d 组合为 QK (query-key) 电路 W h Q ( W h K ) T = W h Q K ∈ R d × d W_{h}^Q (W_{h}^K)^T = W_{h}^{QK} \in \mathbb{R}^{d \times d} WhQ(WhK)T=WhQK∈Rd×d
QK 和 OV 电路可以视为负责分别从残差串流读取和写入的单元。注意力区块的输出是各个注意力头的总和,随后加回到残差串流:
Attn l ( X l − 1 , ≤ i ) = ∑ h = 1 H Attn l , h ( X l − 1 , ≤ i ) \text{Attn}_l(X_{l-1, \leq i}) = \sum_{h=1}^H \text{Attn}_{l,h}(X_{l-1, \leq i}) Attnl(Xl−1,≤i)=h=1∑HAttnl,h(Xl−1,≤i)
x mid , l , i = x l − 1 , i + Attn l ( X l − 1 , ≤ i ) x_{\text{mid},l,i} = x_{l-1,i} + \text{Attn}_l(X_{l-1, \leq i}) xmid,l,i=xl−1,i+Attnl(Xl−1,≤i)
前馈网络区块由两个可学习的权重矩阵组成: W l in ∈ R d × d F F N W_{l}^{\text{in}} \in \mathbb{R}^{d \times d_{FFN}} Wlin∈Rd×dFFN 和 W l out ∈ R d F F N × d W_{l}^{\text{out}} \in \mathbb{R}^{d_{FFN} \times d} Wlout∈RdFFN×d。 W l in W_{l}^{\text{in}} Wlin 从残差串流状态 x mid , l , i x_{\text{mid},l,i} xmid,l,i 读取,其结果通过一个逐元素的非线性激活函数 g g g,产生神经元激活。这些激活再由 W l out W_{l}^{\text{out}} Wlout 转换以产生输出 FFN ( x mid , i ) \text{FFN}(x_{\text{mid},i}) FFN(xmid,i),然后加回到残差串流:
FFN l ( x mid , l , i ) = g ( x mid , l , i W l in ) W l out \text{FFN}_l(x_{\text{mid},l,i}) = g(x_{\text{mid},l,i} W_{l}^{\text{in}}) W_{l}^{\text{out}} FFNl(xmid,l,i)=g(xmid,l,iWlin)Wlout
x l , i = x mid , l , i + FFN l ( x mid , l , i ) x_{l,i} = x_{\text{mid},l,i} + \text{FFN}_l(x_{\text{mid},l,i}) xl,i=xmid,l,i+FFNl(xmid,l,i)
前馈网络的计算可以等同于键值内存检索, W l in W_{l}^{\text{in}} Wlin 中的列向量充当输入序列上的模式检测器(键),每个神经元激活加权了 W l out W_{l}^{\text{out}} Wlout 的列向量(值)。元素式非线性在前馈网络内部创建了一个特权基底,鼓励特征与基底方向对齐。
预测层包含一个去嵌入矩阵 W U ∈ R d × ∣ V ∣ W_U \in \mathbb{R}^{d \times |V|} WU∈Rd×∣V∣,有时还有一个偏差。最后一个残差串流状态透过这个线性映像转换,将表示转换为下一个 token 的 logits 分布,再通过 softmax 函数转换为概率分布。由于模型组件透过相加与残差串流交互作用,未正规化的分数(logits)是透过组件输出的线性投影获得的。基于线性变换的性质,可以重新排列传统的前向传递公式,使每个模型组件直接贡献于预测的 token 的 logits:
f ( x ) = x L n W U = ( ∑ l = 1 L ∑ h = 1 H Attn l , h ( X l − 1 ≤ n ) + ∑ l = 1 L FFN l ( x mid , l n ) + x n ) W U f(x) = x_L^n W_U = \left( \sum_{l=1}^L \sum_{h=1}^H \text{Attn}_{l,h}(X_{l-1}^{\leq n}) + \sum_{l=1}^L \text{FFN}_l(x_{\text{mid},l}^n) + x^n \right) W_U f(x)=xLnWU=(l=1∑Lh=1∑HAttnl,h(Xl−1≤n)+l=1∑LFFNl(xmid,ln)+xn)WU
= ∑ l = 1 L ∑ h = 1 H Attn l , h ( X l − 1 ≤ n ) W U + ∑ l = 1 L FFN l ( x mid , l n ) W U + x n W U = \sum_{l=1}^L \sum_{h=1}^H \text{Attn}_{l,h}(X_{l-1}^{\leq n}) W_U + \sum_{l=1}^L \text{FFN}_l(x_{\text{mid},l}^n) W_U + x^n W_U =l=1∑Lh=1∑HAttnl,h(Xl−1≤n)WU+l=1∑LFFNl(xmid,ln)WU+xnWU
注意力头 logits 更新 前馈网络 logits 更新
这个分解在定位负责预测的组件时扮演重要角色,因为它允许测量每个组件对预测 token 的 logits 的直接贡献。
残差网络就像浅层网络的集成,每个子网定义了计算图中的一条路径。将前向传递分解为:
f ( x ) = x W U + x W 1 O V W U + x W 1 O V W 2 O V W U + x W 2 O V W U f(x) = xW_U + xW_1^{OV} W_U + xW_1^{OV} W_2^{OV} W_U + xW_2^{OV} W_U f(x)=xWU+xW1OVWU+xW1OVW2OVWU+xW2OVWU
直接路径 完整 OV 电路
虚拟注意力头 (V-composition)
链接输入嵌入与去嵌入矩阵的路径称为直接路径。穿过单个 OV 矩阵的路径称为完整 OV 电路。涉及两个注意力头的路径称为虚拟注意力头,执行 V-composition,因为两个头的顺序写入和读取被视为 OV 矩阵的组合。
理解语言模型的内部运作机制意味着定位前向传递中负责特定预测的元素(输入元素、表示和模型组件)。本文介绍了两种不同类型的方法,允许定位模型行为:输入归因和模型组件归因。
输入归因方法通常用于通过估计输入元素(在语言模型的情况下是 tokens)对定义模型预测的贡献来定位模型行为。对于像语言模型这样的神经网络模型,梯度信息经常被用作归因目的的自然度量。基于梯度的归因在这个上下文中涉及 Transformer 在点 x x x 处的一阶泰勒展开,表示为 ∇ f ( x ) ⋅ x + b \nabla f(x) \cdot x + b ∇f(x)⋅x+b。结果梯度 ∇ f w ( x ) ∈ R n × d \nabla f_w(x) \in \mathbb{R}^{n \times d} ∇fw(x)∈Rn×d 直观地捕捉了模型对输入中的每个元素在预测 token w w w 时的敏感度。虽然归因分数是针对输入 token 嵌入的每个维度计算的,但它们通常在 token 层面汇总,以获得对个别 token 影响的更直观概述。这通常是透过取梯度向量相对于第 i i i 个输入嵌入的 L p L_p Lp 范数来完成的:
A Grad , f w ( x → t i ) = ∥ ∇ x i f w ( x ) ∥ p A_{\text{Grad},f_w}^{(x \rightarrow t_i)} = \| \nabla_{x_i} f_w(x) \|_p AGrad,fw(x→ti)=∥∇xifw(x)∥p
透过梯度向量与输入嵌入 ∇ x i f w ( x ) ⋅ x i \nabla_{x_i} f_w(x) \cdot x_i ∇xifw(x)⋅xi 的点积,称为梯度 × 输入方法,可以将这种敏感度转化为重要性估计。然而,这些方法已知会出现梯度饱和和碎裂问题。这一事实促使引入了积分梯度和 SmoothGrad 等方法来过滤嘈杂的梯度信息。例如,积分梯度近似基线输入 x ~ \tilde{x} x~ 和输入 x x x 之间直线路径上的梯度积分:
∫ 0 1 ∇ x i f w ( x ~ + α ( x − x ~ ) ) d α \int_0^1 \nabla_{x_i} f_w(\tilde{x} + \alpha(x - \tilde{x})) d\alpha ∫01∇xifw(x~+α(x−x~))dα
随后提出了适应文本输入离散性的改编。最后,基于层次相关传播 (Layer-wise Relevance Propagation, LRP) 的方法已广泛应用于研究基于 Transformer 的语言模型。这些方法对梯度传播使用自定义规则,以分解每一层的组件贡献,确保它们的总和在整个网络中保持不变。
另一个流行的方法族是透过添加噪声或消融输入元素并测量对模型预测的结果影响来估计输入重要性。例如,可以移除位置 i i i 处的输入 token,结果机率差异 f w ( x ) − f w ( x − x i ) f_w(x) - f_w(x_{-x_i}) fw(x)−fw(x−xi) 可以用作其重要性的估计。如果给予 w w w 的 logit 或机率没有改变,我们可以得出第 i i i 个 token 没有影响的结论。
虽然原始模型内部数据(如注意力权重)通常被认为提供了不忠实的模型行为解释,但最近的方法提出了注意力权重的替代方案来测量中间 token 级别的归因。其中一些替代方案包括使用值加权向量和输出值加权向量的范数,或使用向量距离来估计贡献。这些方法的共同策略涉及使用注意力展开等技术聚合反映上下文混合模式的中间每层归因,得到输入归因分数。
一个重要的限制是,归因的输出 token 属于一个大的词汇空间,在下一个词预测中经常有语义上等价的 token 竞争机率质量。在这种情况下,归因分数很可能误代了驱动模型预测的几个重迭因素,如语法正确性和语义适当性。最近的工作通过提出这些方法的对比公式来解决这个问题,为模型预测 token w w w 而不是替代 token o o o 提供反事实解释。
另一个输入归因的维度涉及识别影响推理时特定模型预测的有影响力的训练样本。这些方法通常被称为训练数据归因 (TDA) 或实例归因方法,并被用来识别数据中的人工因素和语言模型预测中偏差的来源。最近的方法提出通过训练运行模拟来执行 TDA。虽然已建立的 TDA 方法的适用性受到质疑,特别是由于其低效率,但这个领域最近的工作产生了更有效的方法,可以大规模应用于大型生成模型。
早期关于 Transformer 语言模型组件重要性的研究强调了模型能力的高度稀疏性。例如,即使删除模型中相当一部分的注意力头,也可能不会使其下游性能恶化。这些结果激发了一条新的研究路线,研究语言模型中的各种组件如何贡献于其广泛的能力。
让我们称组件 c c c(注意力头或前馈网络)在特定层上对最后一个 token 位置的输出表示为 f c ( x ) f_c(x) fc(x)。等式 (10) 中提出的分解允许我们测量每个模型组件对输出 token w ∈ V w \in V w∈V 的直接 logit 归因 (DLA):
A D L A , f w ( x → c ) = f c ( x ) W U [ : , w ] A_{DLA,f_w}^{(x \rightarrow c)} = f_c(x) W_U[:,w] ADLA,fw(x→c)=fc(x)WU[:,w]
其中 W U [ : , w ] W_U[:,w] WU[:,w] 是 W U W_U WU 的第 w w w 列,即 token w w w 的去嵌入向量。实际上,组件 c c c 的 DLA 表示 c c c 对预测 token 的 logit 的贡献,使用第 2.2 节中描述的模型组件的线性特性。
我们可以将模型的计算视为一个因果模型,并使用因果工具来阐明每个模型组件 c ∈ C c \in C c∈C 在不同位置对预测的贡献。因果模型可以看作是一个有向无环图 (DAG),其中节点是模型计算,边是激活。
我们可以通过改变前向传递中由模型组件计算的某些节点值 f c ( x ) f_c(x) fc(x) 到另一个值 h ~ \tilde{h} h~ 来干预模型,这被称为激活修补。我们可以使用 do 算子表示这个干预: f ( x ∣ do ( f c ( x ) = h ~ ) ) f(x \mid \text{do}(f_c(x) = \tilde{h})) f(x∣do(fc(x)=h~))。然后我们测量修补后预测的变化:
A Patch , f ( x → c ) = diff ( f ( x ) , f ( x ∣ do ( f c ( x ) = h ~ ) ) ) A_{\text{Patch}, f}^{(x \rightarrow c)} = \text{diff}(f(x), f(x \mid \text{do}(f_c(x) = \tilde{h}))) APatch,f(x→c)=diff(f(x),f(x∣do(fc(x)=h~)))
其中 diff ( ⋅ , ⋅ ) \text{diff}(\cdot, \cdot) diff(⋅,⋅) 函数的常见选择包括 KL 散度和 logit/机率差异。修补后的激活 ( h ~ \tilde{h} h~) 可以来自各种来源。一种常见的方法是创建一个具有分布 P patch P_{\text{patch}} Ppatch 的反事实数据集,其中一些关于任务的输入信号被反转。这种方法会导致两种不同类型的消融:
-
重采样干预,其中修补后的激活是从 P patch P_{\text{patch}} Ppatch 的单个样本中获得的,即 h ~ = f c ( x ~ ) , x ~ ∼ P patch \tilde{h} = f_c(\tilde{x}), \tilde{x} \sim P_{\text{patch}} h~=fc(x~),x~∼Ppatch。
-
平均干预,其中多个 P patch P_{\text{patch}} Ppatch 样本的平均激活用于修补,即 h ~ = E x ~ ∼ P patch [ f c ( x ~ ) ] \tilde{h} = \mathbb{E}_{\tilde{x} \sim P_{\text{patch}}}[f_c(\tilde{x})] h~=Ex~∼Ppatch[fc(x~)]。
另外,修补激活的其他来源包括:
-
零干预,其中激活被替换为空向量,即 h ~ = 0 \tilde{h} = 0 h~=0。
-
噪声干预,其中新的激活是通过在受扰动的输入上运行模型获得的,例如 h ~ = f c ( x + ϵ ) , ϵ ∼ N ( 0 , σ 2 ) \tilde{h} = f_c(x + \epsilon), \epsilon \sim N(0, \sigma^2) h~=fc(x+ϵ),ϵ∼N(0,σ2)。
在设计因果干预实验时需要考虑的一个重要因素是设置的生态效度,因为零消融和噪声消融可能使模型偏离自然激活分布,最终破坏组件分析的有效性。
机制可解释性 (Mechanistic Interpretability) 子领域专注于将神经网络逆向工程为人类可理解的算法。MI 的最新研究旨在揭示电路的存在,电路是一组共同交互作用以解决任务的模型组件(子图)。激活修补、logit 归因和注意力模式分析是电路发现的常用技术。
边缘修补和路径修补利用每个模型组件输入是其残差串流中先前组件输出之和这一事实,并考虑直接连接成对模型组件节点的边。路径修补将边缘修补方法推广到多条边,允许更细粒度的分析。例如,使用等式中描述的浅层网络的前向传递分解,我们可以将图中的单层 Transformer 视为由以下组件组成:
f ( x ) = Attn ( X ≤ n ) W u + FFN ( Attn ( X ≤ n ) + x n ) W u + x n W u f(x) = \text{Attn}(X^{\leq n}) W_u + \text{FFN}(\text{Attn}(X^{\leq n}) + x_n)W_u + x_n W_u f(x)=Attn(X≤n)Wu+FFN(Attn(X≤n)+xn)Wu+xnWu
这里的各个路径包括:
-
直接从Attn到logits。
-
通过FFN再到logits的间接路径。
其中每个发送节点Attn L ( X L − 1 ≤ n ) _L(X_{L-1}^{\leq n}) L(XL−1≤n)的副本都与单个路径相关。在这个例子中,分别修补每个发送节点副本允许我们估计Attn L ( X L − 1 ≤ n ) _L(X_{L-1}^{\leq n}) L(XL−1≤n)对输出logits f ( x ) f(x) f(x)的直接和间接效应。一般来说,我们可以将路径修补应用于网络中的任何路径,并测量头部之间的组合、FFN或这些组件对logits的影响。
基于因果干预的电路分析存在几个缺点:
-
它需要为要评估的任务设计输入模板,以及反事实数据集(即定义 P patch P_{\text{patch}} Ppatch)方面的大量努力。
-
在获得组件重要性估计后,需要人工检查和领域知识来分离重要的子图。
-
已经表明,干预会在下游组件的行为中产生二阶效应,在某些情况下甚至会引发类似于自我修复的补偿行为。这种现象会使得难以得出关于每个组件作用的结论。
为了克服这些限制,Conmy等人提出了一种自动电路发现(ACDC)算法,通过迭代移除计算图中的边来自动识别电路的过程。然而,这个过程需要大量的前向传递(每个修补元素一次),在研究大型模型时变得不切实际。
修补的一个有效替代方法是基于梯度的方法,这些方法已经扩展到超越输入归因,以计算中间模型组件的重要性。例如,给定token预测 w w w,为了计算中间层 l l l的归因,表示为 f l ( x ) f_l(x) fl(x),计算梯度 ∇ f w ( f l ( x ) ) \nabla f_w(f_l(x)) ∇fw(fl(x))。Sarti等人将等式中的对比梯度归因公式扩展到使用单次前向和后向传递来定位对正确延续的预测贡献大于错误延续的组件。Nanda等人提出了边缘归因修补(EAP),包括修补前后预测差异的线性近似,以估计计算图中每个边的重要性。这种方法的主要优点是它只需要两次前向传递和一次后向传递就可以获得图中每个边的归因分数。
另一个研究方向是在较低层次的神经网络中寻找可解释的高层次因果抽象概念。这些方法涉及大量的计算搜索,并假设高层次变量与单元或神经元组对齐。为了克服这些限制,Geiger等人提出了分布式对齐搜索(Distributed Alignment Search, DAS),它在通过梯度下降找到的低层次表示空间的非基底对齐子空间上执行分布式交换干预(Distributed Intervention Interchange, DII)。DAS干预在使用语法评估寻找具有因果影响的特征方面被证明是有效的,同时在分离实体的个别属性的因果效应方面也有显著效果。
最后,在Transformer语言模型的内部行为方面,本文总结了以下主要发现:
注意力区块:
位置头:有些头主要关注相对于正在处理的token的特定位置,如token本身、前一个token或下一个位置。
子词连接头: 专门关注属于与当前处理的token相同词的前一个子词token。
语法头: 一些头关注具有与被处理token相关语法角色的token,明显多于随机基准。
重复token: 关注上下文中同一个token的先前出现。
复制头: OV矩阵表现出复制行为。
归纳头: 由两个在不同层组合的头组成,让语言模型完成模式。一个早期的前一个token头将第一个token A的信息复制到B的残差流,然后一个下游的归纳头关注token B,增加B的可能性。
复制抑制头: 如果它们出现在上下文中并且当前残差流正在自信地预测它,则减少它们关注的token的logit分数。
前馈网络区块:
神经元的输入行为:有些神经元仅在特定位置范围上激发;技能神经元,其激活与输入提示的任务相关;概念特定神经元,其反应可用于预测上下文中概念的存在。
神经元的输出行为:有些神经元促进与特定语义和句法概念相关的token的预测;一小部分后层神经元负责做出在语言上可接受的预测;抑制不可能延续的神经元。
多语义神经元: 早期层的大多数神经元专门用于n-gram集合,充当n-gram侦测器,大多数神经元在大量n-gram上激发。
残差流:
残差流可以被视为Transformer中的主要信道。直接路径主要对应于bi-gram统计,而网络中的最新偏差根据词频转移预测,促进高频token。
一些组件执行内存管理,以移除存储在残差流中的信息。例如,有负特征值的OV矩阵关注当前位置的注意力头,以及输入和输出权重具有较大负余弦相似性的前馈网络神经元。
在残差流中发现了离群值维度。这些维度展现出相对于其他维度的大幅度,与各向异性表示的生成相关联。消融离群值维度已被证明会显著降低下游性能,表明它们编码任务特定的知识。
多组件行为:
归纳机制: 是两个组件(注意力头)组合在一起以完成模式的一个明显例子。最近的证据表明,多个注意力头协同工作,在给定上下文样本时创建描述任务的「函数」或「任务」向量。
Variengien 和 Winsor 研究了涉及回答可以在上下文中找到答案的请求的上下文检索
任务:作者确定了一种在子任务和模型之间通用的高层次机制。具体而言,中间层处理请求,然后由后层的注意力头执行从上下文中检索实体的步骤。
在 GPT-2 Small 中发现了用于间接宾语识别 (IOI) 任务的电路,主要包括:
重复信号:重复token头和涉及前一个token头的归纳机制表明S(John)的重复性。这个信息被最后位置的S-抑制头读取,它们在残差流中写入一个token信号,表明S被重复,以及S1 token的位置信号。
名称复制: 后层的名称移动头将它们在上下文中关注的名称的信息复制到最后的残差流。然而,先前层S-抑制头的信号修改了名称移动头的query,使得重复的名称(在S1和S2中)受到较少关注,有利于复制间接宾语(IO),从而推动其预测。
我的感想
这篇论文全面而深入地介绍了目前用于解释 Transformer 语言模型内部运作的技术,并总结了通过这些方法得到的关于模型内部机制的见解。作者强调,虽然在可解释性研究方面取得了显著进展,但将这些见解应用于调试和改进未来模型的安全性和可靠性,为开发人员和用户提供更好的工具来与之交互并理解影响其预测的因素,仍然是一个巨大的挑战。未来可解释性研究的发展将面临从在模型组件空间运作的方法和分析转向人类可解释空间(即从模型组件到特征和自然语言解释)的挑战性任务,同时仍然忠实地反映模型行为。
此外,跨学科研究将在扩大可解释性分析的范围方面发挥关键作用,以考虑从人的角度看模型解释的感知和交互维度。最终,作者认为,确保对先进语言模型的内部机制的开放和便利访问,将仍然是这一领域未来进展的基本先决条件。
这篇综述论文对于理解 Transformer 语言模型的内部运作机制,以及目前在这个领域的最新研究进展,提供了全面而详尽的概览。透过系统性地介绍各种可解释性技术,并深入探讨它们揭示的模型内部行为,本文为相关研究人员提供了宝贵的参考。
不过,正如作者所指出,将这些见解应用到实际中仍面临诸多挑战。未来的可解释性研究需要在忠实反映模型行为的同时,努力向更加贴近人类直观理解的方向发展。这需要不同学科领域的通力合作。此外,开放模型内部机制的访问,或许是这一领域能够取得突破性进展的关键。