如何理解attention中的Q、K、V

参考网址:知乎

Q:Query
K:Key
V:Value

这三个变量代表什么?

其实是三个矩阵,矩阵如果表示为LxD,L是句子中词的个数,D是嵌入维度,在自注意力机制里,QKV是表示同一个句子的矩阵,否则KV一般是来自一个句子,而Q来自其他句子

如何计算QKV

我们直接用torch实现一个SelfAttention来说一说:

  1. 首先定义三个线性变换矩阵,query, key, value:
class BertSelfAttention(nn.Module):
    self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
    self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
    self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768

可以通过这三个线性变换query,key,value得到我们想要的QKV,其中三个变换的输入都是768维,输出都是768维

  1. 假设句子是“我想吃酸菜鱼”,嵌入维度是768,那么该句子表示成矩阵就是6x768维
    在这里插入图片描述

将该矩阵输入上面的三个线性转换,就可以得到三个矩阵KQV,(6x768)X(768x768)=(6x768),维度其实没有改变。
在这里插入图片描述

代码表示为:

class BertSelfAttention(nn.Module):
    def __init__(self, config):
        self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
        self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
        self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
    
    def forward(self,hidden_states): # hidden_states 维度是(L, 768)
        Q = self.query(hidden_states)
        K = self.key(hidden_states)
        V = self.value(hidden_states)

如何计算attention?

在这里插入图片描述
拿自注意力来举例(QKV都是同一个句子的矩阵)

① 首先是Q和K矩阵乘,(L, 768)*(L, 768)的转置=(L,L),看图:
在这里插入图片描述
最后得到(LxL)的矩阵,其中图中蓝色圈圈代表的就是“我”对“我”的注意力值,其他位置的值亦然。
② 然后是除以根号dim,这个dim就是768,至于为什么要除以这个数值?主要是为了缩小点积范围,确保softmax梯度稳定性,再用softmax进行归一化操作(一种解释是为了保证注意力权重的非负性,同时增加非线性)
③ 然后就是刚才的注意力权重和V矩阵乘了,如图:
在这里插入图片描述
首先是“我”这个字对“我想吃酸菜鱼”这句话里面每个字的注意力权重,和V中“我想吃酸菜鱼”里面每个字的第一维特征进行相乘再求和,这个过程其实就相当于用每个字的权重对每个字的特征进行加权求和,然后再用“我”这个字对对“我想吃酸菜鱼”这句话里面每个字的注意力权重和V中“我想吃酸菜鱼”里面每个字的第二维特征进行相乘再求和,依次类推最终也就得到了(L,768)的结果矩阵,和输入保持一致

代码:

class BertSelfAttention(nn.Module):
    def __init__(self, config):
        self.query = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
        self.key = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
        self.value = nn.Linear(config.hidden_size, self.all_head_size) # 输入768, 输出768
    
    def forward(self,hidden_states): # hidden_states 维度是(L, 768)
        Q = self.query(hidden_states)
        K = self.key(hidden_states)
        V = self.value(hidden_states)
        
        attention_scores = torch.matmul(Q, K.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        out = torch.matmul(attention_probs, V)
        return out
  • 30
    点赞
  • 55
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值