Attention

本文详细探讨了如何利用自注意力机制预测序列中的下一个字符,涉及矩阵运算、CausalSelfAttention以及softmax的作用。作者解释了如何通过线性变换和softmax函数使模型中的token间有效通信,区分了self-attention与cross-attention的区别,并提及在翻译任务中的应用。
摘要由CSDN通过智能技术生成

目的:输入一个sequence, 我们预测的下一个字符要与前面的sequence都有关联,不能只看前一个字符来预测下一个。是communication mechanism。

最直观的想法:把前面的token加起来做平均

做平均等价于矩阵乘法:

# x 为输入的序列
[0.33, 0.33, 0.33] @ x
# T 为time,即序列长度
# 对w的行做平均
w = torch.ones(T, T)
w = w / w.sum(1, keepdim=True)
w @ x

又因为序列上每个位置的字符只能看到前面的字符,不能看到后面的(这是CausalSelfAttention)

# w 是下三角矩阵
w = torch.tril(torch.ones(T, T))
w = w / w.sum(1, keepdim=True)
w @ x

等价于:

tril = torch.tril(torch.ones(T, T))
w = torch.zeros((T, T))
w = w.masked_fill(tril==0, float('-inf'))
w = F.softmax(w, dim = -1)
w @ x

其中

softmax ( x i ) = exp ⁡ ( x i ) ∑ j = 1 N exp ⁡ ( x j ) \text{softmax}(x_i) = \frac{\exp(x_i)}{\sum_{j=1}^{N} \exp(x_j)} softmax(xi)=j=1Nexp(xj)exp(xi)

如果需要整个序列上的字符都能communicate to each other(encoder block),去掉masked_fill(tril==0, float('-inf'))就行

那么w如何得到呢?

w要使得sequence中的token能够communicate,我们可以用矩阵乘法来实现

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
# x = (B, T, C)
k = key(x) # (B, T, head_size)
q = query(x) # (B, T, head_size)
w = q @ k.transpose(-2, -1)  # (B, T, head_size) @ (B, head_size, T) --> (B, T, T)

又发现如果直接用q@k.transpose(-2, -1)得到w,w的值过大,我们需要它接近1,则用 w d k \frac{w}{\sqrt{dk}} dk w实现,dk是head_size

那既然使w接近1了,为什么还有用softmax呢?

softmax可以让较大的值更明显,让diffuse的值变converge

x也进行Linear操作,得到value

最终版为:

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
# x = (B, T, C)
k = key(x) # (B, T, head_size)
q = query(x) # (B, T, head_size)
w = q @ k.transpose(-2, -1) * head_size ** -0.5  # (B, T, head_size) @ (B, head_size, T) --> (B, T, T)
w = w.masked_fill(tril==0, float('-inf'))
w = F.softmax(w, dim = -1)
value = nn.Linear(C, head_size, bias=False)
v = value(x)
out = w @ v

注意

  • 是每个batch里进行attention, 不同的batch无法communicate
  • self-attention指key、value与query从同一个x中产生,而cross-attention指query从x中产生,而key、value从其他地方产生。translation中,encode后的数据加入到decoder(cross-attention),使得在翻译过程中,不仅能看到前面的信息,还能看到整个句子链接的信息。
  • 23
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值