Cross Attention(XATTN )pytorch实现

XATTN 是 “Cross Attention” 的缩写,表示交叉注意力机制。这是一种在多模态模型中常用的机制,用于在不同模态(例如,视觉和文本)之间建立联系和融合信息。

交叉注意力机制(Cross Attention)

交叉注意力机制是 Transformer 中的一种变体,通常用于多模态任务,例如视觉问答、图像字幕生成等。它的主要作用是让一个模态(如文本)关注并融合另一个模态(如图像)的信息,从而实现更好的理解和生成。

基本概念
  1. Query、Key、Value

    • Query(查询):来自一个模态的输入向量。
    • Key(键)和 Value(值):来自另一个模态的输入向量。
  2. 计算注意力权重

    • 使用 Query 和 Key 计算注意力权重,表示 Query 对每个 Key 的相关性。
    • 常用的注意力函数是点积注意力(Scaled Dot-Product Attention):
      Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
      其中, Q Q Q 是 Query, K K K 是 Key, V V V 是 Value, d k d_k dk 是 Key 的维度。
  3. 加权求和

    • 使用计算出的注意力权重对 Value 进行加权求和,得到融合后的表示。

交叉注意力的应用

在多模态任务中,交叉注意力机制允许模型在处理文本时参考图像信息,或者在处理图像时参考文本信息。例如:

  1. 图像字幕生成

    • 图像特征作为 Key 和 Value,文本特征作为 Query,通过交叉注意力机制生成描述图像的文本。
  2. 视觉问答

    • 问题文本特征作为 Query,图像特征作为 Key 和 Value,通过交叉注意力机制生成答案。

代码实现


import torch
import torch.nn.functional as F

# 假设文本特征 T 和图像特征 I
T = torch.randn(32, 10, 512)  # (batch_size, text_seq_len, feature_dim)
I = torch.randn(32, 20, 512)  # (batch_size, image_seq_len, feature_dim)

# 计算 Query, Key, Value
Q = T  # Query 来自文本特征,形状 (batch_size, text_seq_len, d_k)
K = I  # Key 来自图像特征,形状 (batch_size, image_seq_len, d_k)
V = I  # Value 来自图像特征,形状 (batch_size, image_seq_len, d_v)

# 获取特征维度
d_k = Q.size(-1)  # d_k 是 Query 和 Key 的特征维度

# 计算注意力得分
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))

# 计算注意力权重
attention_weights = F.softmax(scores, dim=-1)

# 加权求和值
cross_attention_output = torch.matmul(attention_weights, V)

# 输出形状
print(cross_attention_output.shape)  # 输出形状为 (batch_size, text_seq_len, d_v)


文本特征作为 Query,图像特征作为 Key 和 Value,通过交叉注意力机制计算得到融合后的表示。

Query来自文本特征,Key和Value来自图像特征。让我们逐步分析为什么输出的形状是 (batch_size, text_seq_len, d_v)

代码分析

  1. 输入张量:

    • T 是文本特征,形状为 (batch_size, text_seq_len, feature_dim)
    • I 是图像特征,形状为 (batch_size, image_seq_len, feature_dim)
  2. Query, Key, Value 的选择:

    • Q = T:Query来自文本特征,其形状为 (batch_size, text_seq_len, d_k)
    • K = I:Key来自图像特征,其形状为 (batch_size, image_seq_len, d_k)
    • V = I:Value来自图像特征,其形状为 (batch_size, image_seq_len, d_v)
  3. 计算注意力得分:

    scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    
    • scores 的形状为 (batch_size, text_seq_len, image_seq_len),因为它是通过将 (batch_size, text_seq_len, d_k)Q(batch_size, d_k, image_seq_len)K 转置相乘得到的。
  4. 计算注意力权重:

    attention_weights = F.softmax(scores, dim=-1)
    
    • attention_weights 的形状为 (batch_size, text_seq_len, image_seq_len),因为对 image_seq_len 维度进行了 softmax 计算。
  5. 加权求和值(输出):

    cross_attention_output = torch.matmul(attention_weights, V)
    
    • 这里,attention_weights 的形状是 (batch_size, text_seq_len, image_seq_len)V 的形状是 (batch_size, image_seq_len, d_v)
    • 矩阵乘法后,cross_attention_output 的形状是 (batch_size, text_seq_len, d_v)
    • 这意味着对于每个文本序列中的每个词,您都计算了来自图像序列中所有元素的加权和,因而输出的序列长度是 text_seq_len

总结

输出的形状是 (batch_size, text_seq_len, d_v) 是因为在跨模态注意力机制中,文本特征的每个词(Query)通过注意力机制与图像特征(Key和Value)进行交互,得到加权求和的结果,因此输出的序列长度保持为 text_seq_len

  • 4
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值