Transfrom中的多头注意力机制

一,Self-Attention机制

        Self-Attention可以使网络关注到不同位置向量之间的关系,例如The animal didn‘t cross the street because it was too tired. 中 it 特指的是哪个部分,使animal还是street。举个简单的例子说明,例如有两个单词Machine Learning。Self-Attention的注意力机制会产生以下4个关系,Machine-Machine,Machine-Learning,Learning-Machine,Learning-Learning。具体的计算方式如下:

        将词汇通过embedding转换成词向量分别是X1和X2,接着再通过线性变换得到Q,K,V向量,分别是查询,键和值。

import torch
import torch.nn as nn
vol = torch.tensor([0, 1])
embed = nn.Embedding(2, 4)
vol_tensor = embed(vol)

        转为词向量的结果为:

tensor([[ 0.5261,  1.7266, -0.4697,  0.2236],
        [ 1.0151,  1.5949,  0.9423, -1.3259]], grad_fn=<EmbeddingBackward0>)

         接着在对其进行线性转换:

linear1 = nn.Linear(4,4)
linear2 = nn.Linear(4,4)
linear3 = nn.Linear(4,4)

        Q,K,V的结果为

tensor([[ 0.8610, -0.4681,  1.0204, -0.9113],
        [-0.1582,  0.4929, -0.1701, -1.1226]], grad_fn=<AddmmBackward0>)

tensor([[ 0.0797,  0.9090,  0.8206, -0.2743],
        [-0.2588,  0.9723,  0.8719,  0.1857]], grad_fn=<AddmmBackward0>)

tensor([[ 1.1230,  0.3089,  0.8571,  0.3893],
        [ 0.9962, -0.4166,  0.2556, -0.2005]], grad_fn=<AddmmBackward0>)

        接下来就是计算Attention Score(注意力分数),首先是计算每个词向量之间的分数,计算公式为:

Attention = Q\cdot K^{T}

        代码为:

Kt = K.permute(1,0)
attention = torch.matmul(Q, Kt)

        attention结果为

tensor([[0.7304, 0.0426],
        [0.6037, 0.1634]], grad_fn=<MmBackward0>)

        这个结果表示词向量之间的关系,就是Machine-Machine,Machine-Learning,Learning-Machine,Learning-Learning的分数,分数越大,关系越密切。在做好以上的步骤后,将结果除以\sqrt{d_{k}},其中d_{k}是K向量的维度。这个作用是网络在求梯度的时候更加稳定。

d = math.sqrt(k.shape[0])
attention = attention/d
tensor([[0.3652, 0.0213],
        [0.3019, 0.0817]], grad_fn=<DivBackward0>)

        接着在对结果进行一个softmax的处理,为了方便后续处理。

attention = torch.softmax(attention, dim=-1)
tensor([[0.5851, 0.4149],
        [0.5548, 0.4452]], grad_fn=<SoftmaxBackward0>)

        接着再将attention的值和V相乘,对于分数高的位置,说明网络将注意力放到了它们的身上,分数越小,这些词的关联就越小。

attention = torch.matmul(attention, V)
tensor([[ 1.0704,  0.0079,  0.6076,  0.1446],
        [ 1.0666, -0.0141,  0.5894,  0.1267]], grad_fn=<MmBackward0>)

        以上的公式为

                Attention = softmax(\frac{Q\cdot K^{T}}{\sqrt{d_{k}}})\cdot V

 

  • 3
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值