Cross Attention(XATTN )pytorch实现

XATTN 是 “Cross Attention” 的缩写,表示交叉注意力机制。这是一种在多模态模型中常用的机制,用于在不同模态(例如,视觉和文本)之间建立联系和融合信息。

交叉注意力机制(Cross Attention)

交叉注意力机制是 Transformer 中的一种变体,通常用于多模态任务,例如视觉问答、图像字幕生成等。它的主要作用是让一个模态(如文本)关注并融合另一个模态(如图像)的信息,从而实现更好的理解和生成。

基本概念
  1. Query、Key、Value

    • Query(查询):来自一个模态的输入向量。
    • Key(键)和 Value(值):来自另一个模态的输入向量。
  2. 计算注意力权重

    • 使用 Query 和 Key 计算注意力权重,表示 Query 对每个 Key 的相关性。
    • 常用的注意力函数是点积注意力(Scaled Dot-Product Attention):
      Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
      其中, Q Q Q 是 Query, K K K 是 Key, V V V 是 Value, d k d_k dk 是 Key 的维度。
  3. 加权求和

    • 使用计算出的注意力权重对 Value 进行加权求和,得到融合后的表示。

交叉注意力的应用

在多模态任务中,交叉注意力机制允许模型在处理文本时参考图像信息,或者在处理图像时参考文本信息。例如:

  1. 图像字幕生成

    • 图像特征作为 Key 和 Value,文本特征作为 Query,通过交叉注意力机制生成描述图像的文本。
  2. 视觉问答

    • 问题文本特征作为 Query,图像特征作为 Key 和 Value,通过交叉注意力机制生成答案。

代码实现


import torch
import torch.nn.functional as F

# 假设文本特征 T 和图像特征 I
T = torch.randn(32, 10, 512)  # (batch_size, text_seq_len, feature_dim)
I = torch.randn(32, 20, 512)  # (batch_size, image_seq_len, feature_dim)

# 计算 Query, Key, Value
Q = T  # Query 来自文本特征,形状 (batch_size, text_seq_len, d_k)
K = I  # Key 来自图像特征,形状 (batch_size, image_seq_len, d_k)
V = I  # Value 来自图像特征,形状 (batch_size, image_seq_len, d_v)

# 获取特征维度
d_k = Q.size(-1)  # d_k 是 Query 和 Key 的特征维度

# 计算注意力得分
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))

# 计算注意力权重
attention_weights = F.softmax(scores, dim=-1)

# 加权求和值
cross_attention_output = torch.matmul(attention_weights, V)

# 输出形状
print(cross_attention_output.shape)  # 输出形状为 (batch_size, text_seq_len, d_v)


文本特征作为 Query,图像特征作为 Key 和 Value,通过交叉注意力机制计算得到融合后的表示。

Query来自文本特征,Key和Value来自图像特征。让我们逐步分析为什么输出的形状是 (batch_size, text_seq_len, d_v)

代码分析

  1. 输入张量:

    • T 是文本特征,形状为 (batch_size, text_seq_len, feature_dim)
    • I 是图像特征,形状为 (batch_size, image_seq_len, feature_dim)
  2. Query, Key, Value 的选择:

    • Q = T:Query来自文本特征,其形状为 (batch_size, text_seq_len, d_k)
    • K = I:Key来自图像特征,其形状为 (batch_size, image_seq_len, d_k)
    • V = I:Value来自图像特征,其形状为 (batch_size, image_seq_len, d_v)
  3. 计算注意力得分:

    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    
    • scores 的形状为 (batch_size, text_seq_len, image_seq_len),因为它是通过将 (batch_size, text_seq_len, d_k)Q(batch_size, d_k, image_seq_len)K 转置相乘得到的。
  4. 计算注意力权重:

    attention_weights = F.softmax(scores, dim=-1)
    
    • attention_weights 的形状为 (batch_size, text_seq_len, image_seq_len),因为对 image_seq_len 维度进行了 softmax 计算。
  5. 加权求和值(输出):

    cross_attention_output = torch.matmul(attention_weights, V)
    
    • 这里,attention_weights 的形状是 (batch_size, text_seq_len, image_seq_len)V 的形状是 (batch_size, image_seq_len, d_v)
    • 矩阵乘法后,cross_attention_output 的形状是 (batch_size, text_seq_len, d_v)
    • 这意味着对于每个文本序列中的每个词,您都计算了来自图像序列中所有元素的加权和,因而输出的序列长度是 text_seq_len

总结

输出的形状是 (batch_size, text_seq_len, d_v) 是因为在跨模态注意力机制中,文本特征的每个词(Query)通过注意力机制与图像特征(Key和Value)进行交互,得到加权求和的结果,因此输出的序列长度保持为 text_seq_len

### nn.CrossAttention 代码详解及实现原理 #### 背景介绍 在神经网络领域,尤其是自然语言处理(NLP),Transformer 架构中的 Cross Attention 是一种用于捕捉不同序列间依赖关系的重要机制。Cross Attention 基于 Self Attention 发展而来,允许模型关注来自两个不同输入序列的信息流。 #### 实现细节 对于 PyTorch 中 `nn.MultiheadAttention` 类而言,虽然官方文档并没有直接提供名为 `nn.CrossAttention` 的类,但是通过配置参数可以轻松实现跨注意力功能。具体来说,当查询(query)来源于一个序列而键(key)和值(value)来自于另一个不同的序列时,则构成了所谓的 "cross attention"[^1]。 以下是基于 PyTorch 的多头交叉注意层的一个简化版本: ```python import torch.nn as nn import math class CrossAttention(nn.Module): def __init__(self, embed_dim, num_heads): super(CrossAttention, self).__init__() self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) def forward(self, query, key_value_pair): """ Args: query: Tensor of shape (L, N, E), where L is target sequence length, N is batch size and E is embedding dimension. key_value_pair: Tuple containing two tensors both with shapes like the 'query' tensor but representing source sequences. Returns: attn_output: Output from multi-head cross-attention layer. """ key, value = key_value_pair # Perform scaled dot-product attention between queries and keys/values pairs attn_output, _ = self.multihead_attn(query=query, key=key, value=value) return attn_output ``` 这段代码定义了一个简单的 `CrossAttention` 层,其中包含了对给定的查询向量以及键/值对应用多头注意力机制的过程。这里的关键在于如何设置输入数据——即让查询来自目标端(通常是解码器侧),而键和值则取自源端(编码器侧)。这正是实现了从一个序列到另一序列之间的交互学习[^3]。 #### 数学表达式 根据公式\[Attention(Q,K,V)=softmax\left (\frac {QK^{T}}{\sqrt{d_{k}}} \right )V\],可以看到,在计算过程中引入了缩放因子 \( \sqrt{d_k}\) 来稳定梯度更新过程,并最终得到加权求和后的输出矩阵 V。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值