多模态中的交叉注意力Cross Attentionon

CA作为多模态融合的一个重要组成部分,它通过注意力机制在不同模块之间建立联系,促进信息的交流和整合,从而提升了模型处理复杂任务的能力。

使用它需要明白以下几点要求:

  1. 两个序列必须具有相同的维度。(因为Query和Key要做点积累,)

  2. 两个序列可以是不同的模态(如文本、图像)。

  3. 一个序列作为输入的Query,定义了输出的序列长度,另一个序列作为输入的Key和Value。

在这里插入图片描述

具体地说,对于一个文本序列和一个图像序列:

  1. 文本通过一个Transformer编码器处理,输出作为查询向量Query。

  2. 图像通过CNN处理,输出经过线性变换生成键Key和值向量Value。

  3. 计算文本查询向量Query与图像键向量Key的点积,得到注意力分数Attention Score。

  4. 使用这些分数对图像的值向量Value进行加权,生成最终输出。

如上图所示。

代码实现(响应读者需求添加自2025.3.31 6:26)

该段代码是笔者在YoloWorld中对文本模态和图像模态进行交叉注意力计算的代码,标记出了大部分张量的维度。
基本参数:x为视觉模态的输入,text_embedding为文本模态的输入。
PS:暂时未进行修缮和简化,等后续有空进行

class VLCrossAttention(nn.Module):
    def __init__(self, in_channels, emb_dim, att_dropout=0.0, aropout=0.0):
        super(VLCFUModule, self).__init__()
        self.emb_dim = emb_dim
        self.scale = emb_dim ** -0.5

        self.proj_in = nn.Conv2d(in_channels, emb_dim, kernel_size=1, stride=1, padding=0)

        self.Wq = nn.Linear(emb_dim, emb_dim)
        self.Wk = nn.Linear(emb_dim, emb_dim)
        self.Wv = nn.Linear(emb_dim, emb_dim)

        self.proj_out = nn.Conv2d(emb_dim, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x, text_embedding, pad_mask=None):
        '''
        :param x: [batch_size, c, h, w]
        :param text_embedding: shape([bs, 7, 512])
        :param pad_mask: [batch_size, seq_len, seq_len]
        :return:
        '''
        b, c, h, w = x.shape    # stage3: x.shape=([bs, 256, 40, 40])

        x = self.proj_in(x)   # [batch_size, c, h, w] = [bs, 1024, 40, 40]
        x = rearrange(x, 'b c h w -> b (h w) c')   # [batch_size, h*w, emb_dim] = [bs, 1600, 1024]

        Q = self.Wq(x)  # [batch_size, h*w, emb_dim] = [bs, 1600, 1024]
        K = self.Wk(text_embedding)  # [batch_szie, seq_len, emb_dim] = [bs, 7, 1024]
        V = self.Wv(text_embedding)  # 同K
        # K = self.Wk(context)  # [batch_szie, seq_len, emb_dim] = [3, 5, 512]
        # V = self.Wv(context)

        # [batch_size, h*w, seq_len]
        att_weights = torch.einsum('bid,bjd -> bij', Q, K)  # [bs, 1600, 1600]
        att_weights = att_weights * self.scale

        if pad_mask is not None:
            # [batch_size, h*w, seq_len]
            att_weights = att_weights.masked_fill(pad_mask, -1e9)

        att_weights = F.softmax(att_weights, dim=-1)
        out = torch.einsum('bij, bjd -> bid', att_weights, V)   # [batch_size, h*w, emb_dim]

        out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w)   # [batch_size, c, h, w]
        out = self.proj_out(out)   # [batch_size, c, h, w]

        return out
        # return out, att_weights
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

北上ing

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值