Self-Attention、Multi-head Self-Attention

个人笔记

讲的太好了,一听就懂!视频链接

一、 理论

1. Self-Attention、Multi-head Self-Attention最终效果:  

输入:X1  X2      ------self attention------   输出 Y1 Y2   

四者shape相同;

Y1是X1  X2不同权重的加权和;

Y2是X1  X2不同权重的加权和;

2. 计算过程

                  a1  a2  向量            WQ  WK WV 矩阵

shape        1,dmodel                 dmodel,dk 

计算公式如下:

  • 第一步:求取q k v

多个a向量拼接成矩阵;矩阵相乘并行运算速度快

 

 

  • 第二步: 求取权重系数

  • 第三步:加权相加

 

 

3. Multi-head Self-Attention

n个头,就有n组 WQ  WK WV 矩阵;

相较于一个头WQ  WK WV 行数不变,列缩减为原来的n分之一

最终得到n组 q k v

同理,相较于1个头,q k v 行数不变,列缩减为原来的n分之一

假设对a1 a2进行    Multi-head Self-Attention,头数n=2

  • 第一步:求取每一组的q k v

 

 

  • 第二步:

对每个组单独进行Self-Attention(两组互不影响)

 

  • 第三步:拼接

  • 第四步 :融合

 

二 、代码实现:

class Attention(nn.Module):   # 多头注意力机制
    def __init__(self,
                 dim,                   # 输入token的dim     如512
                 num_heads=8,           # 8个头
                 qkv_bias=False,        #偏置
                 qk_scale=None,         # VIT中为None 不用管
                 attn_drop_ratio=0.,
                 proj_drop_ratio=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads     # 多头其实就是分组计算再合并
        self.scale = qk_scale or head_dim ** -0.5     # 也就是公式中的  根号K
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)   # qkv三个矩阵都是512,512 合并在一起512,512*3
                                                            # qkv计算通过全连接实现的
        self.attn_drop = nn.Dropout(attn_drop_ratio)
        self.proj = nn.Linear(dim, dim)               # 完成自注意计算之后得到两个向量,   还要经过一层全连接映射
        self.proj_drop = nn.Dropout(proj_drop_ratio)

    def forward(self, x):    # 牢记输入: (批量,单词数,维度)===(B, N, C)
        # [batch_size, num_patches + 1, total_embed_dim]
        B, N, C = x.shape      # num_patches相当于输入小图片的个数,也就是单词的个数, 因为最开始要加一个标签分类,所以是 num_patches + 1
                               # total_embed_dim也就是dim,也就是单词的维度,eg:512
        # qkv(): -> [batch_size, num_patches + 1, 3 * total_embed_dim]
        # reshape: -> [batch_size, num_patches + 1, 3, num_heads, embed_dim_per_head]  # -3维度代表 qkv  -2维度代表不同的头
        # permute: -> [3, batch_size, num_heads, num_patches + 1, embed_dim_per_head]  # 2维度代表不同的头  0维度代表 qkv
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)   #求取Q K V 多维矩阵相乘,只需要最后两个维度匹配即可
        #
        q, k, v = qkv[0], qkv[1], qkv[2]  # 取出 q k v
        # q, k, v维度: [batch_size, num_heads, num_patches + 1, embed_dim_per_head]

        attn = (q @ k.transpose(-2, -1)) * self.scale  # K要转置,交换最后两个维度,Q K才能相乘
        # 只保证矩阵最后两个维度满足矩阵乘法要求即可  前两个维度不会变[batch_size, num_heads, ... , ...]
        # 这个就是公式 Q * K转置 / 根号d
        # transpose: -> [batch_size, num_heads, embed_dim_per_head, num_patches + 1]-------K转置后的维度
        # [batch_size, num_heads, num_patches + 1, num_patches + 1]---------(Q * K转置 / 根号K)结果的维度

        attn = attn.softmax(dim=-1)   #就是最后一个维度num_patches + 1做softmax
        attn = self.attn_drop(attn)   #

        #  multiply后的shape:  [batch_size, num_heads, num_patches + 1, embed_dim_per_head]
        # transpose后的shape:  [batch_size, num_patches + 1, num_heads, embed_dim_per_head]
        #  reshape后的shape:   [batch_size, num_patches + 1, total_embed_dim]
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)

        x = self.proj(x)  # [batch_size, num_patches + 1, total_embed_dim] * [total_embed_dim, total_embed_dim]
                          #  最终shape:  [batch_size, num_patches + 1, total_embed_dim]
        x = self.proj_drop(x)
        return x

上述代码说明:

num_heads=1就是Self-Attention

num_heads>1就是Multi-head Self-Attention

输出:(batch,seq_len,dim)---------------输出:(batch,seq_len,dim)

其实就是全连接,夹杂着做各种shape变换。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值