[NLP]——Transformer中的attention为什么要做scale?

前言

说起Transformer的self-attention,很容易想到下面的公式:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T / d k ) V Attention(Q,K,V) = softmax(QK^T/\sqrt{d_k})V Attention(Q,K,V)=softmax(QKT/dk )V

假设X是输入,分别经过 W Q W_Q WQ W K W_K WK W V W_V WV映射得到 Q Q Q K K K V V V,【 d k d_k dk=Q.size(-1)=K.size(-1)】,通过 Q Q Q K K K的dot-product来计算 w e i g h t weight weight,经过softmax得到归一化后的 w e i g h t weight weight,再使用这个 w e i g h t weight weight去对 V V V做加权求和。那么,公式 Q K T / d k QK^T/\sqrt{d_k} QKT/dk 中的 d k \sqrt{d_k} dk 是用来干嘛的呢?

一句话概括就是:如果不对softmax的输入做缩放,那么万一输入的数量级很大,softmax的梯度就会趋向于0,导致梯度消失。

本文的思路如下:

  • softmax是怎样求导的?
  • 通过对softmax求导,我们可以知道,softmax的输入的数量级越大,求导的梯度越会趋向于0
  • 如何控制softmax的输入的数量级(也就是 d k \sqrt{d_k} dk 的作用)

softmax的求导过程

这部分主要参考:详解softmax函数以及相关求导过程 - 忆臻的文章 - 知乎
https://zhuanlan.zhihu.com/p/25723112

假设softmax的输入是 x = ( x 1 , x 2 , . . . x n ) x=(x_1,x_2,...x_n) x=(x1,x2,...xn),则 y = s o f t m a x ( x ) = ( y 1 , y 2 , . . . y n ) y = softmax(x) = (y_1,y_2,...y_n) y=softmax(x)=(y1,y2,...yn),其中, y i y_i yi = e x i / ∑ k = 1.. n e x k e^{x_i}/\sum_{k=1..n}e^{x_k} exi/k=1..nexk。求导过程如下(打公式好累,有个错别字“妨”):
在这里插入图片描述

为softmax的输入的数量级越大,求导的梯度越会趋向于0

这部分主要参考transformer中的attention为什么scaled? - TniL的回答 - 知乎
https://www.zhihu.com/question/339723385/answer/782509914

首先,对于输入的 x = ( x 1 , x 2 , . . . x n ) x=(x_1,x_2,...x_n) x=(x1,x2,...xn),softmax中的max体现在:通过一个自然底数e来将输入中最大的元素更大,softmax中的soft体现在:不忽略输入中的最小的元素,依然给它们一定的权重。总之,softmax函数能够将输入中的元素间差距拉大,然后归一化为一个分布。

假设输入的 x = ( x 1 , x 2 , . . . x n ) x=(x_1,x_2,...x_n) x=(x1,x2,...xn)中最大元素为 x k x_k xk,其对应的概率输出为 y k y_k yk,将会呈现“x的数量级越大, y k y_k yk越趋向于1”的趋势,具体举例如下:
在这里插入图片描述
在这种情况下,如果输入 x x x的数量级很大,而假设它的最大值是 x 0 x_0 x0,则经过softmax计算得到的 y y y中,只有 y 0 y_0 y0趋向于1,其它概率元素全都趋向于0。进一步结合第一部分的求导结果,将会出现下图的情况:
在这里插入图片描述
即,softmax的梯度趋向于0

如何控制softmax的输入的数量级

那么如何控制softmax的输入,也就是 x = ( x 1 , x 2 , . . . x n ) x=(x_1,x_2,...x_n) x=(x1,x2,...xn)的数量级呢?首先明确一点,在Transformer中,如果没有 d k \sqrt{d_k} dk ,则,softmax的输入是 Q K T QK^T QKT。宏观的,我们需要保证 Q K T QK^T QKT,也就是一个(batch_size x)?sent_num x sent_n的矩阵,其中的每一个元素的数量级都不要很大。那么取出其中一个元素,它由 q ⋅ k q·k qk计算得来,接下来证明 q ⋅ k q·k qk的数量级与 d k \sqrt{d_k} dk 的关系。

下面的推导设计到概率论与数理统计的知识,还没有复习到相关内容,所以依然先参考:transformer中的attention为什么scaled? - TniL的回答 - 知乎
https://www.zhihu.com/question/339723385/answer/782509914

经过上述推导,得到结论:
在这里插入图片描述

参考

transformer中的attention为什么scaled? - TniL的回答 - 知乎
https://www.zhihu.com/question/339723385/answer/782509914

详解softmax函数以及相关求导过程 - 忆臻的文章 - 知乎
https://zhuanlan.zhihu.com/p/25723112

  • 10
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值