self-attention(自注意力) 和 cross-attetion(交叉注意力) 中的差异

self attention(自注意力)

在这里插入图片描述
来源: https://arxiv.org/pdf/1706.03762

计算实现:
1.计算出 score = Q K T \text{score} =QK^T score=QKT
2.计算 attention = s o f t m a x ( score ) \text{attention}=softmax(\text{score}) attention=softmax(score)
3. 计算 weighted = a t t e n t i o n ∗ V \text{weighted}=attention*V weighted=attentionV

1.计算复杂度是
O ( L 2 ) O(L^2) O(L2)
2.因为需要计算 LXL 的 注意力矩阵

softmax ( Q K T d ) \text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) softmax(d QKT)

完整公式

self_attention ( Q , K , V ) = softmax ( Q K T d ) \text{self\_attention}(Q, K, V)=\text{softmax}\left(\frac{QK^T}{\sqrt{d}}\right) self_attention(Q,K,V)=softmax(d QKT)
在这里插入图片描述
来源:https://arxiv.org/pdf/2009.14794

代码实现:

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, input_dim):
        super(SelfAttention, self).__init__()
        self.input_dim = input_dim
        # 定义线性变换层,用于计算查询、键和值
        self.query = nn.Linear(input_dim, input_dim)  # [batch_size, seq_length, input_dim]
        self.key = nn.Linear(input_dim, input_dim)    # [batch_size, seq_length, input_dim]
        self.value = nn.Linear(input_dim, input_dim)  # [batch_size, seq_length, input_dim]
        self.softmax = nn.Softmax(dim=2)  # 注意力权重的softmax函数,沿着最后一个维度进行

    def forward(self, x):  # x的形状为 (batch_size, seq_length, input_dim)
        queries = self.query(x)  # 计算查询矩阵
        keys = self.key(x)       # 计算键矩阵
        values = self.value(x)   # 计算值矩阵

        # 计算注意力得分
        score = torch.bmm(queries, keys.transpose(1, 2)) / (self.input_dim ** 0.5)
        attention = self.softmax(score)  # 对得分应用softmax函数得到注意力权重
        weighted = torch.bmm(attention, values)  # 使用注意力权重加权值矩阵
        return weighted  # 返回加权后的值矩阵

Cross Attention(交叉注意力)

在这里插入图片描述
这张图片展示了交叉注意力模块的工作原理。

交叉注意力模块

  • 输入:

    • “What?”:这是表示“内容”的输入序列,包含值(Value,(V))和键(Key,(K))。
    • “Where?”:这是表示“位置”的输入序列,包含查询(Query,(Q))。
  • 计算过程:

    • 从“内容”输入序列中提取出值 (V) 和键 (K)。
    • 从“位置”输入序列中提取出查询 (Q)。
    • 计算查询 (Q) 和键 (K) 的点积,得到注意力能量(Attention energy)。
    • 将注意力能量除以 (\sqrt{C/h}),其中 (C) 是键的维度,(h) 是注意力头的数量,用以进行缩放。
    • 对缩放后的注意力能量应用 softmax 函数,得到注意力权重。
    • 将注意力权重应用到值 (V) 上,得到输出上下文(Output context)。

数学公式:

Cross_attention ( Q , K , V ) = Softmax ( Q K T C / h ) ⋅ V   \text{Cross\_attention}(Q, K, V) = \text{Softmax}\left(\frac{QK^T}{\sqrt{C/h}}\right) \cdot V \ Cross_attention(Q,K,V)=Softmax(C/h QKT)V 

  • 解释:
    • ( Q ):查询矩阵。
    • ( K ):键矩阵。
    • ( V ):值矩阵。
    • (\text{Softmax}):softmax 函数,用于将注意力能量转换为概率分布。
    • ( \sqrt{C/h} ):缩放因子,控制注意力能量的大小。

代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class CrossAttention(nn.Module):
    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        super().__init__()
        # 定义线性变换层,用于计算查询、键和值
        self.q_proj   = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        self.k_proj   = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.v_proj   = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads  # 注意力头的数量
        self.d_head = d_embed // n_heads  # 每个注意力头的维度
    
    def forward(self, x, y):
        # x (潜在表示): (batch_size, seq_len_q, dim_q)
        # y (上下文): (batch_size, seq_len_kv, dim_kv) = (batch_size, 77, 768)

        input_shape = x.shape
        batch_size, sequence_length, d_embed = input_shape
        
        # 将每个查询的嵌入向量划分为多个头,确保 d_heads * n_heads = dim_q
        interim_shape = (batch_size, -1, self.n_heads, self.d_head)
        
        # 计算查询矩阵 (batch_size, seq_len_q, dim_q) -> (batch_size, seq_len_q, dim_q)
        q = self.q_proj(x)
        # 计算键矩阵 (batch_size, seq_len_kv, dim_kv) -> (batch_size, seq_len_kv, dim_q)
        k = self.k_proj(y)
        # 计算值矩阵 (batch_size, seq_len_kv, dim_kv) -> (batch_size, seq_len_kv, dim_q)
        v = self.v_proj(y)

        # 将查询矩阵重塑并转置以匹配注意力头 (batch_size, seq_len_q, dim_q) -> (batch_size, seq_len_q, h, dim_q / h) -> (batch_size, h, seq_len_q, dim_q / h)
        q = q.view(interim_shape).transpose(1, 2) 
        # 将键矩阵重塑并转置以匹配注意力头 (batch_size, seq_len_kv, dim_q) -> (batch_size, seq_len_kv, h, dim_q / h) -> (batch_size, h, seq_len_kv, dim_q / h)
        k = k.view(interim_shape).transpose(1, 2) 
        # 将值矩阵重塑并转置以匹配注意力头 (batch_size, seq_len_kv, dim_q) -> (batch_size, seq_len_kv, h, dim_q / h) -> (batch_size, h, seq_len_kv, dim_q / h)
        v = v.view(interim_shape).transpose(1, 2) 
        
        # 计算注意力得分 (batch_size, h, seq_len_q, dim_q / h) @ (batch_size, h, dim_q / h, seq_len_kv) -> (batch_size, h, seq_len_q, seq_len_kv)
        weight = q @ k.transpose(-1, -2)
        
        # 缩放注意力得分 (batch_size, h, seq_len_q, seq_len_kv)
        weight /= math.sqrt(self.d_head)
        
        # 对注意力得分应用softmax函数 (batch_size, h, seq_len_q, seq_len_kv)
        weight = F.softmax(weight, dim=-1)
        
        # 计算加权后的值矩阵 (batch_size, h, seq_len_q, seq_len_kv) @ (batch_size, h, seq_len_kv, dim_q / h) -> (batch_size, h, seq_len_q, dim_q / h)
        output = weight @ v
        
        # 将输出矩阵转置回原始形状 (batch_size, h, seq_len_q, dim_q / h) -> (batch_size, seq_len_q, h, dim_q / h)
        output = output.transpose(1, 2).contiguous()
        
        # 将输出矩阵重塑回原始形状 (batch_size, seq_len_q, h, dim_q / h) -> (batch_size, seq_len_q, dim_q)
        output = output.view(input_shape)
        
        # 应用最后的线性变换 (batch_size, seq_len_q, dim_q) -> (batch_size, seq_len_q, dim_q)
        output = self.out_proj(output)

        # 返回最终的输出 (batch_size, seq_len_q, dim_q)
        return output

代码来源
https://github.com/hkproj/pytorch-stable-diffusion/blob/main/sd/attention.py

  • 10
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
seq2seq-attention是指在seq2seq模型引入了注意力机制(Attention)。在传统的seq2seq模型,编码器将输入序列转化为一个固定长度的向量,然后解码器将这个向量解码成输出序列。而在seq2seq-attention模型,解码器在每个时间步都会根据输入序列的不同部分给予不同的注意力权重,从而更加关注与当前时间步相关的输入信息。这样可以提高模型对输入序列的理解能力,进而提升预测的准确性。引入注意力机制后,seq2seq-attention模型在翻译、文本摘要和问答等任务上有着更好的表现。\[1\]\[2\] #### 引用[.reference_title] - *1* [NLP自然语言处理之RNN--LSTM--GRU--seq2seq--attention--self attetion](https://blog.csdn.net/weixin_41097516/article/details/103174768)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* *3* [自注意力机制(Self-Attention):从Seq2Seq模型到一般RNN模型](https://blog.csdn.net/qq_24178985/article/details/118683144)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^insertT0,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值