attention注意力机制【对应图的代码讲解】

题目

'''
Description: attention注意力机制
Autor: 365JHWZGo
Date: 2021-12-14 17:06:11
LastEditors: 365JHWZGo
LastEditTime: 2021-12-14 22:23:54
'''

注意力机制三步式+分步代码讲解

在这里插入图片描述

导入库

import torch 
import torch.nn as nn
import torch.nn.functional as F

Attn

class Attn(nn.Module):
    def __init__(self,query_size,key_size,value_size1,value_size2):
        super(Attn,self).__init__()
        self.query_size = query_size
        self.key_size = key_size
        self.value_size1 = value_size1
        self.value_size2 = value_size2
        
        self.attn = nn.Linear(self.query_size+self.key_size,value_size1)
    
    def forward(self,q,k,v):
        
        # attn_weights=(1,32)
        attn_weights = F.softmax(self.attn(torch.concat((q[0],k[0]),1)),dim=1)
        # attn_weights.unsqueeze(0)=(1,1,32)
        # v=(1,32,64)
        # attn_applied=(1,1,64)
        output = torch.bmm(attn_weights.unsqueeze(0),v)
        
        return output,attn_weights

attn函数是将合成【Query|Key】,进行列合并
f ( Q , K ) = W a [ Q , K ] f(Q,K) = W_a[Q,K] f(Q,K)=Wa[Q,K]

attn_weights的结果对应于a1,a2,a3…
在这里插入图片描述
output是计算Attention Value,bmm相当于a1value1+a2value2+…【矩阵乘法】
在这里插入图片描述

if __name__ == "__main__":
    query_size = 32
    key_size = 32
    value_size1 = 32
    value_size2 = 64
    
    attn = Attn(query_size, key_size, value_size1, value_size2)
    Q = torch.randn(1,1,32)
    K = torch.randn(1,1,32)
    V = torch.randn(1,32,64)
    out = attn(Q, K ,V)
    print(out[0])
    print(out[1])

运行结果

在这里插入图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

365JHWZGo

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值