自注意力机制

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads  = num_heads
        self.scale      = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop  = nn.Dropout(attn_drop)
        self.proj       = nn.Linear(dim, dim)
        self.proj_drop  = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C     = x.shape
        qkv         = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v     = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

这个代码定义了一个自注意力机制(Self-Attention)模块,通常用于Transformer模型。让我们逐步分析这个Attention类的构造和它的前向传播过程。

1. __init__ 方法

__init__方法是这个类的构造函数,用于初始化这个自注意力模块的各个组件。

def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
    super().__init__()
    self.num_heads  = num_heads
    self.scale      = (dim // num_heads) ** -0.5

    self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
    self.attn_drop  = nn.Dropout(attn_drop)
    self.proj       = nn.Linear(dim, dim)
    self.proj_drop  = nn.Dropout(proj_drop)
  • dim:输入嵌入的维度,即输入向量的长度。
  • num_heads:多头注意力机制中的头数,默认为8。
  • qkv_bias:是否在生成查询(Query)、键(Key)、值(Value)向量时使用偏置(bias),默认为False。
  • attn_drop:注意力权重的dropout概率,用于防止过拟合。
  • proj_drop:在最后的投影层应用的dropout概率。

组件说明:

  • self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias):这是一个线性层,将输入x投影到查询、键和值三个子空间中。它的输出维度是 dim * 3,因为它同时生成qkv
  • self.scale = (dim // num_heads) ** -0.5:缩放因子,用于对qk的点积进行缩放,这样可以避免点积值过大导致梯度不稳定。
  • self.attn_drop = nn.Dropout(attn_drop):对注意力权重进行dropout。
  • self.proj = nn.Linear(dim, dim):对自注意力机制的输出进行线性投影。
  • self.proj_drop = nn.Dropout(proj_drop):对投影后的输出进行dropout。

2. forward 方法

forward方法定义了这个模块的前向传播过程,即输入数据如何通过这个模块进行计算并得到输出。

def forward(self, x):
    B, N, C     = x.shape
    qkv         = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v     = qkv[0], qkv[1], qkv[2]

    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x
2.1 输入形状
  • x的形状是 (B, N, C),其中:
    • B 是批次大小(batch size)。
    • N 是序列长度(如句子的词数或图像的patch数)。
    • C 是嵌入向量的维度(即dim)。
2.2 生成 qkv
  • qkv = self.qkv(x):对输入 x 应用线性层,生成一个形状为 (B, N, 3 * C) 的张量。
  • reshape 操作将 qkv 重塑为形状 (B, N, 3, self.num_heads, C // self.num_heads)。这里3表示qkv三个部分,self.num_heads是多头注意力中的头数。
  • permute(2, 0, 3, 1, 4):将维度重新排序,得到形状 (3, B, self.num_heads, N, C // self.num_heads)
    • qkv[0]:查询向量 q,形状为 (B, self.num_heads, N, C // self.num_heads)
    • qkv[1]:键向量 k,形状相同。
    • qkv[2]:值向量 v,形状相同。
2.3 计算注意力权重
  • attn = (q @ k.transpose(-2, -1)) * self.scale:计算 qk 的点积,然后乘以缩放因子self.scale,得到注意力权重矩阵attn。其中 k.transpose(-2, -1) 表示将键向量的最后两个维度互换,以便进行矩阵乘法,最终得到一个形状为 (B, self.num_heads, N, N) 的矩阵。

  • attn = attn.softmax(dim=-1):对注意力权重矩阵的最后一个维度应用softmax函数,使其成为概率分布。

  • attn = self.attn_drop(attn):对注意力权重进行dropout,以防止过拟合。

2.4 计算注意力输出
  • x = (attn @ v):将注意力权重矩阵attn与值向量v相乘,得到加权后的输出。

  • x.transpose(1, 2).reshape(B, N, C):调整维度顺序并重塑张量,使其形状恢复为 (B, N, C)

2.5 投影和输出
  • x = self.proj(x):通过线性层对输出 x 进行投影,将其映射回原始的嵌入维度C
  • x = self.proj_drop(x):对投影后的输出进行dropout。
  • return x:返回最终的输出。

总结

这个 Attention 模块实现了一个带有多头机制的自注意力层,常用于Transformer架构中。它通过线性变换将输入数据生成查询、键、值向量,计算注意力权重,再使用这些权重对值向量进行加权求和,最终通过一个线性层投影输出结果。这个模块的目的是让每个位置(如词、图像块)根据序列中的其他位置动态地调整自己的表示,捕捉序列内部的依赖关系。

  • 20
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值