CA作为多模态融合的一个重要组成部分,它通过注意力机制在不同模块之间建立联系,促进信息的交流和整合,从而提升了模型处理复杂任务的能力。
使用它需要明白以下几点要求:
-
两个序列必须具有相同的维度。(因为Query和Key要做点积累,)
-
两个序列可以是不同的模态(如文本、图像)。
-
一个序列作为输入的Query,定义了输出的序列长度,另一个序列作为输入的Key和Value。
具体地说,对于一个文本序列和一个图像序列:
-
文本通过一个Transformer编码器处理,输出作为查询向量Query。
-
图像通过CNN处理,输出经过线性变换生成键Key和值向量Value。
-
计算文本查询向量Query与图像键向量Key的点积,得到注意力分数Attention Score。
-
使用这些分数对图像的值向量Value进行加权,生成最终输出。
如上图所示。
代码实现(响应读者需求添加自2025.3.31 6:26)
该段代码是笔者在YoloWorld中对文本模态和图像模态进行交叉注意力计算的代码,标记出了大部分张量的维度。
基本参数:x为视觉模态的输入,text_embedding为文本模态的输入。
PS:暂时未进行修缮和简化,等后续有空进行
class VLCrossAttention(nn.Module):
def __init__(self, in_channels, emb_dim, att_dropout=0.0, aropout=0.0):
super(VLCFUModule, self).__init__()
self.emb_dim = emb_dim
self.scale = emb_dim ** -0.5
self.proj_in = nn.Conv2d(in_channels, emb_dim, kernel_size=1, stride=1, padding=0)
self.Wq = nn.Linear(emb_dim, emb_dim)
self.Wk = nn.Linear(emb_dim, emb_dim)
self.Wv = nn.Linear(emb_dim, emb_dim)
self.proj_out = nn.Conv2d(emb_dim, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x, text_embedding, pad_mask=None):
'''
:param x: [batch_size, c, h, w]
:param text_embedding: shape([bs, 7, 512])
:param pad_mask: [batch_size, seq_len, seq_len]
:return:
'''
b, c, h, w = x.shape # stage3: x.shape=([bs, 256, 40, 40])
x = self.proj_in(x) # [batch_size, c, h, w] = [bs, 1024, 40, 40]
x = rearrange(x, 'b c h w -> b (h w) c') # [batch_size, h*w, emb_dim] = [bs, 1600, 1024]
Q = self.Wq(x) # [batch_size, h*w, emb_dim] = [bs, 1600, 1024]
K = self.Wk(text_embedding) # [batch_szie, seq_len, emb_dim] = [bs, 7, 1024]
V = self.Wv(text_embedding) # 同K
# K = self.Wk(context) # [batch_szie, seq_len, emb_dim] = [3, 5, 512]
# V = self.Wv(context)
# [batch_size, h*w, seq_len]
att_weights = torch.einsum('bid,bjd -> bij', Q, K) # [bs, 1600, 1600]
att_weights = att_weights * self.scale
if pad_mask is not None:
# [batch_size, h*w, seq_len]
att_weights = att_weights.masked_fill(pad_mask, -1e9)
att_weights = F.softmax(att_weights, dim=-1)
out = torch.einsum('bij, bjd -> bid', att_weights, V) # [batch_size, h*w, emb_dim]
out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w) # [batch_size, c, h, w]
out = self.proj_out(out) # [batch_size, c, h, w]
return out
# return out, att_weights