Transformer代码学习——手写多头注意力(网课)

原作者视频链接:

【研1基本功 (真的很简单)注意力机制】手写多头注意力机制_哔哩哔哩_bilibili先看文档哈 https://dwexzknzsh8.feishu.cn/docx/VkYud3H0zoDTrrxNX5lce0S4nDh?from=from_copylink慢慢更新,一周更完transformer同步更新RLHF、LoRA, 视频播放量 77272、弹幕量 52、点赞数 1996、投硬币枚数 1202、收藏人数 7088、转发人数 391, 视频作者 happy魇, 作者简介 Enjoy life | 欢迎来到我的主页!,相关视频:注意力机制的本质|Self-Attention|Transformer|QKV矩阵,【深度学习缝合模块】废材研究生自救指南!12个最新模块缝合模块创新!-CV、注意力机制、SE模块,膜拜!浙大教授竟把Transformer讲的如此简单!全套【Transformer基础】课程分享,连草履虫都能学会!再学不会UP下跪!,【可视化】Transformer中多头注意力的计算过程,24种魔改注意力机制 暴力涨点 即插即用 CNN+注意力机制,Transformer论文逐段精读【论文精读】,【官方双语】直观解释注意力机制,Transformer的核心 | 【深度学习第6章】,【研1基本功 (真的很简单)Group Query-Attention】大模型训练必备方法——bonus(位置编码讲解),视觉十分钟|通道注意力原理(SENet,CBAM,SRM,ECA,FcaNet)|计算机视觉/通道注意力/总结分享,图神经网络改进-手把手教你改代码-第1期icon-default.png?t=N7T8https://www.bilibili.com/video/BV1o2421A7Dr/

1. 注意事项

用q乘以k的转置之后,所做的softmax是对矩阵的竖排做,而不是横着做softmax

2.代码部分

import torch
from torch import nn
import torch.functional as f
import math
#%%
# 测试数据
X = torch.randn( 128, 64, 512) # Batch,Time,Dimension
print(X.shape)
#%%
# 设置multihead_attention基本参数
d_model = 512  # 映射到Q,K,V空间中有多少位
n_head = 8 # 有多少个头
#%%
class multi_head_attention(nn.Module):
    def __init__(self, d_model,n_head) -> None:
        super(multi_head_attention,self).__init__()

        self.n_head = n_head
        self.d_model = d_model
        self.w_q = nn.Linear(d_model, d_model)  # 线性层映射函数,把初始向量映射到Q,K,V(query,key,value)
        # 简单来说就是去寻找一些query去跟key,问他(key)哪些数据是跟我匹配的上的,匹配上之后,key所对应的value值进行加权组合,最终得到attention的输出
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_combine = nn.Linear(d_model, d_model) # 由于是多头注意力,所以要在最后做一个组合映射(多写一个w_combine的线性映射)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v):
        batch, time, dimension = q.shape
        n_d = self.d_model // self.n_head   # 得到新维度
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)  # 把qkv分别丢到上面定义的三个线性映射层中,就可以得到qkv空间中的一个表示
        # 对空间表示进行切分,对我们需要得到几个头进行切分
        q = q.view(batch, time, self.n_head, n_d).permute(0, 2, 1, 3) # 把q进行维度划分,一维是batch,二维是time, 三维是n.head(分成几个头),四维是n.d(分完头之后的维度)
        k = k.view(batch, time, self.n_head, n_d).permute(0, 2, 1, 3) # 也可以说把最后一维拆成了n.head和n.d两个维度的乘积
        v = v.view(batch, time, self.n_head, n_d).permute(0, 2, 1, 3) # 做attention操作的时候head维是不能放在最后的,对最后两个维度进行处理,所以要用permute指令做一个维度变换
        # 原先的维度是0,1,2,3现在则是0,2,1,3

        score = q @ k.transpose(2, 3) / math.sqrt(n_d) # q乘以k的转置除以它的维度开根号(让方差变小) @是矩阵乘法
        # torch.tril命令-生成下三角矩阵(左下角都是1,右上角都是0)
        mask = torch.tril(torch.ones(time, time, dtype=bool))
        score = score.masked_fill(mask == 0, float("-inf")) # 把mask等于0的地方都填充为负无穷
        # 填充为负无穷的原因:softmax操作时e^-inf就是0,就相当于我们不去care后面部分的信息
        score = self.softmax(score) @ v

        # 最后把得分的格式变回来(因为之前把time维和self.n_head维度进行了旋转,现在则是要旋转回来),然后再过一个连续性函数
        score = score.permute(0, 2, 1, 3).contiguous().view(batch, time, dimension)
        # contiguous()的作用是让整个矩阵序列在内存中都是连续的

        output = self.w_combine(score)
        return output

attention = multi_head_attention(d_model, n_head)
output = attention(X, X, X)
print(output, output.shape)

  • 18
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值