RefFormer论文精读

基本信息

摘要部分

Visual Grounding(视觉定位)旨在根据给定的自然语言表达,在图像中定位所提及的对象。近年来,基于DETR的视觉定位方法因其无需依赖额外的努力(例如预先生成的候选区域或预定义的锚框)即可直接预测目标对象的坐标而受到了广泛关注。然而,现有研究主要集中于设计更强大的多模态解码器,这些解码器通常通过随机初始化或使用语言嵌入来生成可学习的查询(queries)。这种原始的查询生成方法不可避免地增加了模型的学习难度,因为它在解码开始时没有包含任何与目标相关的信息。此外,它们在查询学习过程中仅使用最深层的图像特征,忽略了其他层特征的重要性。
为了解决这些问题,我们提出了一种名为RefFormer的新方法。它包含一个查询适应模块,该模块可以无缝集成到CLIP中,并生成参照查询(referential query)以为解码器提供先验上下文,同时还有一个任务特定的解码器。通过将参照查询整合到解码器中,我们可以有效地降低解码器的学习难度,并准确地聚焦于目标对象。此外,我们提出的查询适应模块还可以充当适配器,在无需调整骨干网络参数的情况下保留CLIP中的丰富知识。大量的实验证明了我们提出方法的有效性和效率,在五个视觉定位基准测试上均优于最先进的方法。

提出的问题

  • The queries that are inputted to the decoder in these methods are typically generated through random initialization or by utilizing linguistic embeddings.(这些解码器通常通过随机初始化或使用语言嵌入来生成可学习的查询(queries))
  • 这种与目标无关的查询不可避免地增加了解码器的学习难度。
  • 在查询学习过程中,这些方法倾向于只关注骨干网络的最深层视觉特征,而忽略了对于定位任务至关重要且存在于低层和中层特征中的纹理信息。

动机和方案

  • 我们能否为解码器生成与目标相关的参照查询(referential queries),以减轻解码器面临的学习难度?
  • 我们如何有效地将多层视觉上下文信息融入查询学习过程?
  • 考虑到 CLIP 携带着丰富的视觉-语言对齐知识,因此我们将其作为我们方法的骨干网络。
  • 提出了一种名为 RefFormer 的新方法。我们的方法整合了一个查询适应(QA)模块,用于生成参照查询,它为解码器提供了与目标相关的上下文(如图 1 (b) 所示)
  • 策略性地将 QA 模块插入到 CLIP 的不同层中,查询可以自适应地从多层图像特征图中学习目标相关信息,并逐层迭代地精炼所获取的信息。此外,我们提出的 RefFormer 还可以充当适配器(adapter),使得 CLIP 可以保持冻结(参数不更新,降低训练成本),并保留其原有的丰富知识。
  • 它采用了双向(bi-directional interaction)交互机制,通过引入少量可训练的参数来执行多模态融合,并在整个特征提取过程中通过残差连接将新的任务特定知识注入到 CLIP 中。

相关工作部分

In the object detection field, DETR presents an end-to-end object detection model that is built in an encoder-decoder transformer architecture. However, it suffers from slow training convergence. To address this issue, some follow-up works [55, 19, 49, 45, 20, 26, 24, 54] solve this issue by optimizing the learnable queries in DETR. For instance, Anchor DETR [45] directly treats 2D reference points as queries, while DAB-DETR [24] further investigates the role of queries in DETR and proposes the use of 4D anchor boxes as queries. In contrast to these model-level improvements, DN-DETR [20] introduces query denoising training to mitigate the instability of bipartite graph matching, which is further enhanced by DINO [54].
在目标检测领域,DETR 提出了一种基于编码器-解码器 Transformer 架构的端到端目标检测模型,但该模型存在训练收敛缓慢的问题。为了解决这一问题,后续研究工作 [55, 19, 49, 45, 20, 26, 24, 54] 主要通过优化 DETR 中的可学习查询(learnable queries)来加快训练收敛速度。例如,Anchor DETR [45] 直接将二维参考点作为查询,DAB-DETR [24] 则进一步探索了查询的作用,并提出使用四维锚框作为查询。与这些模型级别的改进不同,DN-DETR [20] 引入了查询去噪训练机制以缓解二分图匹配过程中的不稳定性,而该机制在 DINO [54] 中得到了进一步增强。
Additionally, similar research has been explored in other tasks [22, 13, 37]. For example, EaTR [13] formulates a video as a set of event units and treats video-specific event units as dynamic moment queries in video grounding tasks. MTR++ [37] introduces distinct learnable intention queries generated by the k-means clustering algorithm to handle trajectory prediction across different motion modes in motion prediction tasks.
此外,类似的查询优化思想也被应用于其他任务中 [22, 13, 37]。例如,EaTR [13] 将视频建模为一组事件单元,并在视频定位任务中将这些视频特定的事件单元作为动态时刻查询使用。MTR++ [37] 则通过 K-means 聚类算法生成具有差异性的可学习意图查询,用于处理多种运动模式下的轨迹预测任务。

用于目标检测的DETR在训练时收敛比较慢。其中一个可以优化的环节就是模型内部用来“寻找答案”的“可学习查询(Learnable Queries)”,优化“可学习查询”是提升现代AI模型性能和效率的一个重要方向。无论是赋予查询更明确的物理意义(如空间位置或锚框),改进训练过程的稳定性,还是将查询思想巧妙地应用到视频理解、运动预测等不同领域,其核心都是为了让模型能够更智能、更高效地从复杂数据中提取和定位所需信息。

什么是“可学习查询”?为什么要优化它?

  • 可以把“可学习查询”想象成模型内部派出的一堆“小侦探”。在目标检测任务中,每个“侦探”的任务是去图片中找到一个物体。
  • 优化可学习查询:就是想办法让这些“小侦探”变得更聪明、更高效。例如,改进它们出发前的“装备”(初始化方式)或者它们“搜寻目标的方法”(设计)。目标是让它们能更快、更准确地找到物体,从而让整个模型训练得更快、效果更好。

如何在目标检测中优化查询?赋予查询“空间感”!

  • Anchor DETR:给查询一个“起点”

    • 它让每个查询直接代表图片中的一个二维参考点(可以想象成在图片上预先“画”的点)。
    • 模型学习的是:这些查询所代表的参考点,应该在图片的哪个位置,才最有可能找到物体。
    • 效果:这样一来,查询不再是漫无目的的,而是有了明确的“空间坐标”作为起点,帮助模型更快地定位到物体。
  • DAB-DETR:让查询变成“初始猜测框”

    • 更进一步,DAB-DETR 让每个查询代表一个四维的锚框(Anchor Box)。这个锚框包含了物体可能的位置(x, y坐标)和大小(宽度、高度)的初步猜测。
    • 效果:查询直接化身为模型对物体可能在哪、可能多大的“初始蓝图”。这为模型搜索和精调物体边界框提供了一个更具体、更强的起点。

优化训练策略:稳定匹配,间接助力查询

  • DN-DETR 和 DINO:通过“去噪训练”稳定学习过程
    • 背景:DETR类模型在训练时,有一个关键步骤叫做“二分图匹配”。它负责将模型预测出的物体框与图片中真实的物体框进行配对,这样才能计算误差,指导模型学习。但这个匹配过程早期可能不太稳定(因为模型一开始预测得很乱,不知道哪个预测对应哪个真实物体)。
    • 解决方法:它们引入了“查询去噪训练”技术。简单来说,就是在训练时,故意给输入数据(比如真实的物体框坐标或查询本身)添加一些“干扰噪声”,然后要求模型学习如何从这些“带噪声的数据”中恢复出“干净的原始数据”。
    • 效果:这种“在干扰中学习”的方式,迫使模型更鲁棒,有助于稳定二分图匹配过程,从而加快整体的训练收敛速度。注意,这里主要是改进了训练的方法,而不是查询本身代表的意义或结构。

“查询”概念的应用拓展:不止于目标检测

“查询”或“探针”这种机制,不仅仅用在检测图片中的物体,它在其他AI任务中也大显身手:

  • 视频定位 (Video Grounding)

    • 任务:根据一句话描述(如“一个人在跑步机上跑步”),在一段长视频中准确找到对应的片段。
    • EaTR模型的做法:它将视频看作一系列连续的“事件单元”(比如视频里一小段连贯的动作或场景)。然后,它把这些与视频内容紧密相关的“事件单元”作为“动态时刻查询”来使用。这些查询就像可调整的“时间锚点”,帮助模型在视频中搜索并定位与文字描述最匹配的那个精确“时刻”。“动态”可能指这些查询会根据视频内容或搜索过程进行调整。
  • 运动预测 (Motion Prediction)

    • 任务:预测一个物体(如汽车、行人)接下来的运动轨迹。
    • MTR++模型的做法:现实中,物体可能有多种运动意图(如直行、左转、减速)。为了捕捉这种多样性,MTR++引入了“有差异性的可学习意图查询”。
      • 这些查询不再是通用的,而是每一个都代表一种特定的运动“意图”或“模式”(比如一个查询专门负责预测“左转”轨迹,另一个负责“直行”轨迹)。
      • 如何实现差异性:模型在训练前,会通过聚类算法(如k-means)分析大量真实的运动轨迹数据,将相似的轨迹归为一类(形成一个“模式”),然后用这些模式的特征来初始化对应的“意图查询”。
      • 效果:通过这些代表不同意图的查询,MTR++能够同时预测出物体未来多种可能的运动轨迹,每条轨迹对应一种潜在意图,这比只预测单一轨迹更加全面和实用。

figure1

特征提取部分

考虑到 CLIP 在视觉-语言对齐方面令人印象深刻的能力,我们将其作为我们方法的主干backbone,用于提取图像和文本表示,并在训练期间保持参数冻结。特征提取过程表示如下:

图片特征提取

图像编码器: 对于一个输入图像 V ∈ R H × W × 3 V \in \mathbb{R}^{H \times W \times 3} VRH×W×3,它被分割成 N N N 个不重叠的图像块(patch),每个块的大小为 P × P P \times P P×P,其中 N v = H × W P 2 N_v = \frac{H \times W}{P^2} Nv=P2H×W。接下来,这些图像块被展平(flattened)为一组向量,表示为 { x v i ∈ R 3 P 2 } i = 1 N \left\{\mathbf{x}_{v}^{i} \in \mathbb{R}^{3 P^{2}}\right\}_{i=1}^{N} {xviR3P2}i=1N。然后,这些向量通过一个线性投影层 ϕ e ( ⋅ ) \phi_e(\cdot) ϕe() 变换为 token 嵌入(token embeddings)。此外,一个分类 token x c l s ∈ R D x_{cls} \in \mathbb{R}^D xclsRD 被添加到 token 嵌入的开头。随后,位置嵌入 E v \mathbf{E}_{v} Ev 被加入,并应用层归一化(layer normalization, LN)。这个过程可以表示如下:

Z v 0 = L N ( [ x c l s ; ϕ e ( X v ) ] + E v ) \mathbf{Z}_v^0=L N\left(\left[\mathbf{x}_{c l s} ; \phi_e\left(\mathbf{X}_v\right)\right]+\mathbf{E}_{v}\right) Zv0=LN([xcls;ϕe(Xv)]+Ev)

其中 [ ; ] [;] [;] 表示连接(concatenate)操作。然后,token 序列 Z v 0 Z_v^0 Zv0 被输入到 L L L 个 transformer 层。每个 transformer 层包含两个子模块:多头自注意力(multi-head self-attention, MHSA)和多层感知器(multilayer perceptron, MLP),每个子模块之前都进行了层归一化。

Z ˉ v i = M H S A ( L N ( Z v i − 1 ) ) + Z v i − 1 , i = 1 , . . . , L (2) \bar{Z}_v^i = MHSA(LN(Z_v^{i-1})) + Z_v^{i-1}, \quad i=1,...,L \tag{2} Zˉvi=MHSA(LN(Zvi1))+Zvi1,i=1,...,L(2)
Z v i = M L P ( L N ( Z ˉ v i ) ) + Z ˉ v i (3) Z_v^i = MLP(LN(\bar{Z}_v^i)) + \bar{Z}_v^i \tag{3} Zvi=MLP(LN(Zˉvi))+Zˉvi(3)

其中 Z v i ∈ R N × D Z_v^i \in \mathbb{R}^{N \times D} ZviRN×D 表示第 i i i 个 transformer 层的输出。

文本特征提取

文本编码器: 给定一个指代表达 T T T,它首先使用小写字节对编码(lower-cased byte pair encoding)表示转换为一个词嵌入(word embeddings)序列 X t X_t Xt。词嵌入会用 [SOS][EOS] token 包围起来,生成一个长度为 N t N_t Nt 的序列。与图像编码器类似,这些 token 会与位置嵌入 E t E_t Et 相加,并通过 L L L 个 transformer 层来提取文本表示:

Z ˉ t i = M H S A ( L N ( Z t i − 1 ) ) + Z t i − 1 , i = 1 , . . . , L (4) \bar{Z}_t^i = MHSA(LN(Z_t^{i-1})) + Z_t^{i-1}, \quad i=1,...,L \tag{4} Zˉti=MHSA(LN(Zti1))+Zti1,i=1,...,L(4)
Z t i = M L P ( L N ( Z ˉ t i ) ) + Z ˉ t i (5) Z_t^i = MLP(LN(\bar{Z}_t^i)) + \bar{Z}_t^i \tag{5} Zti=MLP(LN(Zˉti))+Zˉti(5)

其中 Z t 0 = [ x s o s ; X t ; x e o s ] + E t Z_t^0 = [x_{sos}; X_t; x_{eos}] + E_t Zt0=[xsos;Xt;xeos]+Et,表示文本编码器中词嵌入层(word embeddings)的输出。

方法部分

figure2

Query Adaptation Module 查询自适应模块 (QA)

QA 模块:(如图 3 所示),该模块可以生成参照查询,为解码器提供与目标相关的上下文,从而增强解码器的 grounding 能力。重要的是,我们的方法将多级特征整合到查询学习过程中,使查询能够捕获更全面的目标对象信息,并可以逐层优化精炼。此外,QA 还可以作为适配器,使得无需对整个骨干网络的参数进行微调。

figure3

降维投影:考虑到从骨干网络的第 i i i 层获得的图像和语言表示 Z v i Z_v^i Zvi Z t i Z_t^i Zti,我们首先使用 MLP 层 ϕ v d i ( ⋅ ) \phi_{vd}^i(\cdot) ϕvdi() ϕ t d i ( ⋅ ) \phi_{td}^i(\cdot) ϕtdi() 将它们投影到较低维度的特征,以减少计算内存
F v i = ϕ v d i ( Z v i ) , F t i = ϕ t d i ( Z t i ) (6) F_v^i = \phi_{vd}^i(Z_v^i), \quad F_t^i = \phi_{td}^i(Z_t^i) \tag{6} Fvi=ϕvdi(Zvi),Fti=ϕtdi(Zti)(6)

条件聚合和多模态融合 (CAMF): 我们随机初始化 N q N_q Nq 个可学习查询 Q ∈ R N q × D l Q \in \mathbb{R}^{N_q \times D_l} QRNq×Dl,其中 D l D_l Dl 表示投影后的维度。这些查询经过专门设计,用于捕获潜在目标对象的上下文。接下来,我们将这些查询与图像特征连接起来,并将它们连同语言特征一起输入到 CAMF 块中。具体来说,CAMF 块主要由一个交叉注意力(cross-attention)层组成,该层分别以查询和图像特征 [ Q ; F v ] [Q; F_v] [Q;Fv] 和语言特征 F t F_t Ft 作为查询 Q Q Q。这种方法不仅使我们能够将表达条件融入可学习查询 Q Q Q,还可以从其他模态中提取相关信息,从而促进目标相关跨模态特征的融合。此外,我们引入了两个可学习的调节 token r v , r t ∈ R D l r_v, r_t \in \mathbb{R}^{D_l} rv,rtRDl (learnable regulation tokens,具体解释见附4)来调节(modulate)每个 QA 的最终输出。这个过程可以形式化如下:

r ˉ v , Q ˉ c i , F ˉ v i = M H C A ( [ r v ; Q i − 1 ; F v i ] , F t i , F t i ) (7) \bar{r}_v, \bar{\mathbf{Q}}_c^i, \bar{\mathbf{F}}_v^i = MHCA([\mathbf{r}_v; \mathbf{Q}^{i-1}; \mathbf{F}_v^i], \mathbf{F}_t^i, \mathbf{F}_t^i) \tag{7} rˉv,Qˉci,Fˉvi=MHCA([rv;Qi1;Fvi],Fti,Fti)(7)

Q ^ c i = L N ( Q ˉ c i ) + Q i − 1 , F ^ v i = L N ( F ˉ v i ) + F v i (8) \hat{\mathbf{Q}}_c^i = LN(\bar{\mathbf{Q}}_c^i) + \mathbf{Q}^{i-1}, \quad \hat{\mathbf{F}}_v^i = LN(\bar{\mathbf{F}}_v^i) + \mathbf{F}_v^i \tag{8} Q^ci=LN(Qˉci)+Qi1,F^vi=LN(Fˉvi)+Fvi(8)

r ˉ v , F ˉ t i = M H C A ( [ r t ; F t i ] , F v i , F v i ) , F ^ t i = L N ( F ˉ t i ) + F t i (9) \bar{r}_v, \bar{\mathbf{F}}_t^i = MHCA([\mathbf{r}_t; \mathbf{F}_t^i], \mathbf{F}_v^i, \mathbf{F}_v^i), \quad \hat{\mathbf{F}}_t^i = LN(\bar{\mathbf{F}}_t^i) + \mathbf{F}_t^i \tag{9} rˉv,Fˉti=MHCA([rt;Fti],Fvi,Fvi),F^ti=LN(Fˉti)+Fti(9)

其中 Q i − 1 Q^{i-1} Qi1 表示从前一个 QA 输出的可学习查询,而 Q 0 Q^0 Q0 是随机初始化的。符号 [ ; ] [;] [;] 表示连接concat操作,而 MHCA ( ⋅ , ⋅ , ⋅ ) \text{MHCA}(\cdot, \cdot, \cdot) MHCA(,,) LN ( ⋅ ) \text{LN}(\cdot) LN() 分别表示多头交叉注意力层和层归一化。

目标相关上下文细化 (TR): 接下来,我们将查询 Q ^ c \hat{Q}_c Q^c 和多模态增强特征图 F ^ v i \hat{F}_v^i F^vi F ^ t i \hat{F}_t^i F^ti 输入到 TR 块中。首先,我们使用聚合了条件的查询 Q ^ c \hat{Q}_c Q^c 与多模态增强图像特征图 F ^ v i \hat{F}_v^i F^vi 进行交互,以细化其中的目标相关视觉上下文。
Q v i = MHCA ( Q ^ c i , F ^ v i , F ^ v i ) , Q i = LN ( MLP ( Q v i ) ) + Q ^ c i (10) Q_v^i = \text{MHCA}(\hat{Q}_c^i, \hat{F}_v^i, \hat{F}_v^i), \quad Q^i = \text{LN}(\text{MLP}(Q_v^i)) + \hat{Q}_c^i \tag{10} Qvi=MHCA(Q^ci,F^vi,F^vi),Qi=LN(MLP(Qvi))+Q^ci(10)

此外,对于聚合了其他模态信息的特征图 F ^ v i \hat{F}_v^i F^vi F ^ t i \hat{F}_t^i F^ti,我们使用自注意力进一步增强它们的目标相关上下文语义:
r ~ v , F ~ v i = M H S A ( [ r ˉ v ; F ^ v i ] , F ^ v i , F ^ v i ) , G v i = L N ( M L P ( F ~ v i ) ) + F ^ v i (11) \tilde{r}_v, \tilde{\mathbf{F}}_v^i = MHSA([\bar{r}_v; \hat{\mathbf{F}}_v^i], \hat{\mathbf{F}}_v^i, \hat{\mathbf{F}}_v^i), \quad \mathbf{G}_v^i = LN(MLP(\tilde{\mathbf{F}}_v^i)) + \hat{\mathbf{F}}_v^i \tag{11} r~v,F~vi=MHSA([rˉv;F^vi],F^vi,F^vi),Gvi=LN(MLP(F~vi))+F^vi(11)

r ~ t , F ~ t i = M H S A ( [ r ˉ v ; F ^ t i ] , F ^ t i , F ^ t i ) , G t i = L N ( M L P ( F ~ t i ) ) + F ^ t i (12) \tilde{r}_t, \tilde{\mathbf{F}}_t^i = MHSA([\bar{r}_v; \hat{\mathbf{F}}_t^i], \hat{\mathbf{F}}_t^i, \hat{\mathbf{F}}_t^i), \quad \mathbf{G}_t^i = LN(MLP(\tilde{\mathbf{F}}_t^i)) + \hat{\mathbf{F}}_t^i \tag{12} r~t,F~ti=MHSA([rˉv;F^ti],F^ti,F^ti),Gti=LN(MLP(F~ti))+F^ti(12)

上采样投影: 最后,我们利用 MLP 将图像和语言特征的通道维度恢复到其原始大小。然后,这些特征以残差方式作为输入传递给骨干网络的下一层。在此之前,我们利用调节 token 来调节特征 G v G_v Gv G t G_t Gt,这有助于防止多模态信号 overpowering 原始信号。

Z ^ v i = ϕ v u i ( G v i × σ ( r ~ v ) ) + Z v i , Z ^ t i = ϕ t u i ( G t i × σ ( r ~ t ) ) + Z t i (13) \hat{Z}_v^i = \phi_{vu}^i (G_v^i \times \sigma(\tilde{r}_v)) + Z_v^i, \hat{Z}_t^i = \phi_{tu}^i (G_t^i \times \sigma(\tilde{r}_t)) + Z_t^i \tag{13} Z^vi=ϕvui(Gvi×σ(r~v))+Zvi,Z^ti=ϕtui(Gti×σ(r~t))+Zti(13)

其中 ϕ v u i ( ⋅ ) \phi_{vu}^i (\cdot) ϕvui() ϕ t u i ( ⋅ ) \phi_{tu}^i (\cdot) ϕtui() 表示 MLP 层,而 σ ( ⋅ ) \sigma (\cdot) σ() 表示 sigmoid 函数。

最后,通过迭代执行上述过程,查询 Q Q Q 可以逐步聚焦于目标相关上下文,并生成参照查询为解码器提供先验上下文。

上采样投影这一步是把之前经过处理(融合了图像和语言信息)的特征,通过 MLP 层,把它们的维度或大小恢复到和原始特征一样的水平。
然后,这些恢复后的特征并不会直接替换掉原始特征,而是以一种叫做“残差连接”的方式加回到骨干网络(主网络)的对应层中。这就像是把处理过的信息作为一种“补充”或“调整”加到原始信息上,保留了原始信息的底子。
关键点是“调节 token”: 在把处理过的特征加回原始特征之前,会用一个叫做“调节 token”的东西去调整modulate)这些处理过的特征 ( G v G_v Gv G t G_t Gt)。
为什么需要这个调节 token? 这是为了控制经过多模态融合后的信息对原始信息的影响程度。有时候,多模态信息可能会太“强”,直接加回去会压制或破坏原始的图像/语言信息。调节 token 就像一个“门控”或“权重”,它可以学习如何适当地调整多模态信号的强度,确保它能有效地融入原始信号,而不是完全取代或干扰原始信号。这样可以更好地平衡不同来源的信息。
总的来说:这一步就是将融合了多模态信息的特征,通过维度恢复和残差连接的方式加回主网络,并且利用调节 token 来精细控制这种融合的强度,以达到更好的效果。

Decoding with Referential Query 使用参照查询进行解码

语言引导的多级融合: 通过在 CLIP 的不同层插入 QA 模块,可以使用多级图像特征图自适应地更新参照查询。此外,为了增强解码器中的图像特征,我们在语言引导下聚合多级视觉特征,以获得语言感知的多级图像特征。具体来说,给定一个多级图像特征集 { Z ^ v k } \{ \hat{Z}_v^k \} {Z^vk} (包括低、中和高层),其中 k ∈ K k \in \mathcal{K} kK 表示选定的层索引,我们使用 MHCA 将语言特征 Z t l a s t Z_t^{last} Ztlast (文本编码器的最终输出) 注入到每个级别的图像特征中:
H s o s = ϕ m t ( Z t l a s t ) , H v k = ϕ m v ( Z ^ v k ) (14) H_{sos} = \phi_{mt}(Z_t^{last}), \quad H_v^k = \phi_{mv}(\hat{Z}_v^k) \tag{14} Hsos=ϕmt(Ztlast),Hvk=ϕmv(Z^vk)(14)

H ^ v k = MHCA ( H v k , H s o s , H s o s ) + H v k , k ∈ K (15) \hat{H}_v^k = \text{MHCA}(H_v^k, H_{sos}, H_{sos}) + H_v^k, \quad k \in \mathcal{K} \tag{15} H^vk=MHCA(Hvk,Hsos,Hsos)+Hvk,kK(15)

其中 ϕ m t ( ⋅ ) \phi_{mt}(\cdot) ϕmt() ϕ m v ( ⋅ ) \phi_{mv}(\cdot) ϕmv() 表示用于将特征映射到相同维度的线性投影函数。此外, H s o s H_{sos} Hsos 表示 H t H_t Ht 中的 [SOS] token,它提取文本的全局信息。随后,通过简单的连接concat生成多级(multi-layer)语言感知图像特征 H ˉ v m l \bar{H}_{vml} Hˉvml,然后通过线性投影函数 ϕ v m l ( ⋅ ) \phi_{vml}(\cdot) ϕvml() 映射到原始维度:
H ˉ v m l = Concat ( { H ^ v k } ) , k ∈ K (16) \bar{H}_{vml} = \text{Concat}(\{ \hat{H}_v^k \}), \quad k \in \mathcal{K} \tag{16} Hˉvml=Concat({H^vk}),kK(16)

H v m l = ϕ v m l ( H ˉ v m l ) (17) H_{vml} = \phi_{vml}(\bar{H}_{vml}) \tag{17} Hvml=ϕvml(Hˉvml)(17)

解码: 接下来,我们首先初始化与参照查询 Q 大小相同的查询 Q’(具体解释见附5),并将它们相加以利用 Q 中的先验上下文。请注意,为了避免在初始阶段来自 Q’ 的干扰,我们将 Q’ 初始化为一个全零矩阵。然后,我们将查询与图像特征连接起来,与语言特征 H t H_{t} Ht 进行交互,以聚合条件信息并生成多模态特征图 H m m H_{mm} Hmm。这可以表示为:
O ˉ c , H ˉ m m = MHCA ( [ ϕ q ( Q ) + Q ′ ; H v m l ] , H t , H t ) (18) \bar{O}_c, \bar{H}_{mm} = \text{MHCA}([\phi_q(Q) + Q'; H_{vml}], H_t, H_t) \tag{18} Oˉc,Hˉmm=MHCA([ϕq(Q)+Q;Hvml],Ht,Ht)(18)

O c = LN ( O ˉ c ) + O ˉ c , H m m = LN ( H ˉ m m ) + H ˉ m m (19) O_c = \text{LN}(\bar{O}_c) + \bar{O}_c, \quad H_{mm} = \text{LN}(\bar{H}_{mm}) + \bar{H}_{mm} \tag{19} Oc=LN(Oˉc)+Oˉc,Hmm=LN(Hˉmm)+Hˉmm(19)

其中 ϕ q ( ⋅ ) \phi_q(\cdot) ϕq() 是 MLP 层,用于调节查询 Q 的重要性。当重要性趋近于零时,查询退化为普通查询。然后,我们将查询 O c O_c Oc 和多模态特征图 H m m H_{mm} Hmm 输入到 MHCA 层中,以提取目标嵌入 O ∈ R N q × D O \in \mathbb{R}^{N_q \times D} ORNq×D。这可以表示为:
O ˉ = MHCA ( O c , H m m , H m m ) (20) \bar{O} = \text{MHCA}(O_c, H_{mm}, H_{mm}) \tag{20} Oˉ=MHCA(Oc,Hmm,Hmm)(20)

O = LN ( ϕ r ( O ˉ ) ) + O ˉ (21) O = \text{LN}(\phi_r(\bar{O})) + \bar{O} \tag{21} O=LN(ϕr(Oˉ))+Oˉ(21)

其中 ϕ r ( ⋅ ) \phi_r(\cdot) ϕr() 表示线性投影函数。

Grounding Head: 我们在目标嵌入 O O O 之上构建了两个 MLP ( ϕ b o x ( ⋅ ) \phi_{box}(\cdot) ϕbox() ϕ c l s ( ⋅ ) \phi_{cls}(\cdot) ϕcls())。最终输出包括目标对象的预测中心坐标,表示为 b = ( x , y , h , w ) ∈ R 4 b = (x, y, h, w) \in \mathbb{R}^4 b=(x,y,h,w)R4,以及包含目标对象的预测置信度分数 y ∈ R 2 y \in \mathbb{R}^2 yR2

b = ϕ b o x ( O ) , y = ϕ c l s ( O ) (22) b = \phi_{box}(O), y = \phi_{cls}(O)\tag{22} b=ϕbox(O),y=ϕcls(O)(22)

训练目标

与 DETR 类似,我们采用二分匹配来找到预测 { b , y } \{b, y\} {b,y} 与地面真实目标 { b t g t , y t g t } \{b_{tgt}, y_{tgt}\} {btgt,ytgt} 之间的最佳匹配。在我们的例子中,类别预测是包含目标对象的查询的置信度预测。为了监督训练,我们使用框预测损失 (L1 和 GIoU) 以及匹配后的交叉熵损失。

L d e t = λ i o u L i o u ( b g t , b ) + λ L 1 ∣ ∣ b g t − b ∣ ∣ + λ c e L c e ( y g t , y ) (23) \mathcal{L}_{det} = \lambda_{iou} \mathcal{L}_{iou}(b_{gt}, b) + \lambda_{L1} ||b_{gt} - b|| + \lambda_{ce} \mathcal{L}_{ce}(y_{gt}, y)\tag{23} Ldet=λiouLiou(bgt,b)+λL1∣∣bgtb∣∣+λceLce(ygt,y)(23)

其中 λ \lambda λ 表示相应的损失权重。此外,为了鼓励每个 QA 模块中的参照查询有效地聚焦于目标相关上下文,我们还引入了类似于上述目标函数的辅助损失 L a u x \mathcal{L}_{aux} Laux 来对其进行监督。最终的训练目标可以定义为:

L f i n a l = L d e t + λ a u x L a u x (24) \mathcal{L}_{final} = \mathcal{L}_{det} + \lambda_{aux} \mathcal{L}_{aux}\tag{24} Lfinal=Ldet+λauxLaux(24)

其中 λ a u x \lambda_{aux} λaux 表示辅助损失的权重。

{ b , y } \{b, y\} {b,y}:模型的预测

  • b b b: 这是模型“预测”出来的边界框 (bounding box)。所以 b ∈ R 4 b \in \mathbb{R}^4 bR4 表示这是一个四维向量。模型会预测出多个这样的边界框,因为在“解码”部分,模型会根据多个查询 (Queries) 生成多个潜在的目标位置预测。
  • y y y: 这是模型对这个预测框的置信度 (confidence score)。这里的 y y y 是一个关于“这个查询(以及它对应的预测框)是否包含目标物体”的置信度预测。它表示模型有多确定这个框里确实是语言描述的目标物体。 y ∈ R 2 y \in \mathbb{R}^2 yR2 可能表示一个二分类的输出,比如一个数值表示是目标物体的概率,另一个表示不是的概率。
  • 所以, { b , y } \{b, y\} {b,y} 就代表了模型做出的一组预测:一个边界框和它对应的置信度。因为模型会预测多个潜在目标,所以实际上模型会输出多组 { b , y } \{b, y\} {b,y}

二分匹配 (Bipartite Matching)

  • 模型一次会输出多个预测框 { b 1 , y 1 } , { b 2 , y 2 } , … , { b N , y N } \{b_1, y_1\}, \{b_2, y_2\}, \dots, \{b_N, y_N\} {b1,y1},{b2,y2},,{bN,yN} (其中 N N N 是模型预测框的数量)。但图片中只有一个真实的目标框 b t g t b_{tgt} btgt。怎么知道哪个预测框对应的是这个真实目标呢?
  • 二分匹配就是用来解决这个问题的。它会找到一个最佳的匹配方式,将模型的预测框与真实的目标框(或“无目标”这个类别)一一对应起来。这样,我们才能知道应该用哪个预测框 b b b 和哪个预测置信度 y y y 去和真实的 b t g t b_{tgt} btgt y t g t y_{tgt} ytgt 计算损失。

主要损失 L d e t \mathcal{L}_{det} Ldet:检测损失

  • 这是用来衡量模型最终预测(也就是 Grounding Head 输出的 { b , y } \{b, y\} {b,y})与地面真实目标 { b t g t , y t g t } \{b_{tgt}, y_{tgt}\} {btgt,ytgt} 之间的差异。它由三部分组成:
    • L i o u ( b g t , b ) \mathcal{L}_{iou}(b_{gt}, b) Liou(bgt,b)GIoU 损失。衡量预测框 b b b 和真实框 b g t b_{gt} bgt 的重叠程度和相对位置。越接近,损失越小。这有助于模型预测出更准确的框的位置和形状。
    • ∣ ∣ b g t − b ∣ ∣ L 1 ||b_{gt} - b||_{L1} ∣∣bgtbL1L1 损失。直接计算预测框坐标和真实框坐标之间的绝对差值之和。也是用来惩罚框位置和大小的偏差。
    • L c e ( y g t , y ) \mathcal{L}_{ce}(y_{gt}, y) Lce(ygt,y)交叉熵损失。衡量预测置信度 y y y 和真实置信度 y t g t y_{tgt} ytgt 之间的差异。如果模型预测的置信度与真实情况(比如,被匹配到真实目标时预测高分,未匹配到时预测低分)相符,损失就小。
  • 公式中的 λ i o u , λ L 1 , λ c e \lambda_{iou}, \lambda_{L1}, \lambda_{ce} λiou,λL1,λce 是权重,用来调整这三部分损失在总损失中的重要性。

辅助损失 L a u x \mathcal{L}_{aux} Laux

  • 作用: 主要损失 L d e t \mathcal{L}_{det} Ldet 只监督模型最后的输出。但这个模型有很多中间层(特别是 QA 模块),这些中间层也在处理信息并生成查询。辅助损失的作用就是给这些中间层的输出也提供监督信号。
  • 为什么需要辅助损失? 如果只看最终结果,中间层可能会学到一些对最终预测没有直接贡献,甚至是有害的东西。通过在中间层也计算一个损失,并让它向着正确方向优化,可以确保模型从早期阶段就开始学习聚焦于目标相关的上下文,使得整个模型的训练过程更稳定、更有效,并可能提升最终性能。
  • 具体内容: 论文中提到 L a u x \mathcal{L}_{aux} Laux “类似于上述目标函数”,这意味着可能在每个 QA 模块的输出(或者某种基于 QA 输出的预测)上,也计算一个类似于 L d e t \mathcal{L}_{det} Ldet 的损失,比如预测一个临时的框和置信度,并与地面真实目标进行比较。这强制要求每个 QA 层产生的查询都能更好地指向目标物体。
  • λ a u x \lambda_{aux} λaux:辅助损失的权重,用来平衡它与主要损失的重要性。

扩展到密集 Grounding

除了对象级别的 Grounding,我们的方法可以通过加入一个分割头轻松扩展到密集 Grounding 任务。具体来说,类似于 MaskFormer,我们利用 MLP 将目标嵌入 O O O 转换为掩码嵌入 M ∈ R N q × D M \in \mathbb{R}^{N_q \times D} MRNq×D。二值掩码预测 s i ∈ [ 0 , 1 ] H × W s_i \in [0, 1]^{H \times W} si[0,1]H×W 然后通过掩码嵌入 M M M 和多模态特征图 H m m H_{mm} Hmm 之间的点积计算得到,接着是一个 sigmoid 激活。在训练过程中,我们使用掩码预测损失 (Focal loss 和 Dice loss),其定义如下:

L s e g = λ f o c a l L f o c a l ( s g t , s ) + λ d i c e L d i c e ( s g t , s ) (25) \mathcal{L}_{seg} = \lambda_{focal} \mathcal{L}_{focal}(s_{gt}, s) + \lambda_{dice} \mathcal{L}_{dice}(s_{gt}, s) \tag{25} Lseg=λfocalLfocal(sgt,s)+λdiceLdice(sgt,s)(25)

其中 s g t s_{gt} sgt 表示ground-truth掩码。(具体解释见附3)

密集 Grounding”(Dense Grounding)是一个更精细的任务。它的目标不仅仅是用一个框把目标物体围起来,而是要预测出一个像素级别的掩码(segmentation mask)。这个掩码就像是把目标物体的精确轮廓勾勒出来,指出图片中的每一个像素点是属于目标物体还是背景。

  1. 增加“分割头”(Segmentation Head): 原来的模型有一个“Grounding Head”,用来预测边界框 b b b 和置信度 y y y。现在为了做分割,模型在 Grounding Head 的位置 又增加了一个专门负责预测分割掩码的模块,就叫做“分割头”。
  2. 利用“目标嵌入 O O O”: 之前的步骤中,模型通过一系列计算,最后得到了“目标嵌入” O O O 。这些 O O O 是模型对目标物体的高度抽象表示,包含了它认为目标物体可能在哪里、长什么样等信息。现在,这个新的分割头就利用了这些已经学到的目标嵌入 O O O 作为输入。
  3. O O O 转换为“掩码嵌入” M M M 分割头里的第一步是使用一个 MLP将目标嵌入 O O O 转换成“掩码嵌入” M M M。这是一个专门为生成掩码而准备的特征表示。虽然 O O O 包含了目标信息,但 M M M 是把这些信息调整和组织成最适合用来预测像素级别掩码的形式。
  4. 计算掩码预测 s s s 这是核心步骤。模型通过计算掩码嵌入 M M M 和之前生成的多模态特征图 H m m H_{mm} Hmm 之间的点积(dot product)来得到像素级别的预测 s s s
    • 多模态特征图 H m m H_{mm} Hmm 这个 H m m H_{mm} Hmm 是融合了图像信息和语言信息的特征图,它保留了图像的空间结构(特征的位置对应图片中的位置)。
    • 点积的意义: 点积可以看作是一种相似度计算。模型对每个潜在的目标(对应一个掩码嵌入 M M M),会拿着这个 M M M H m m H_{mm} Hmm 中的每一个像素位置对应的特征计算点积。如果某个像素位置的特征与这个目标对应的掩码嵌入 M M M “很相似”,点积的值就会比较高,说明这个像素很可能属于这个目标物体。反之则点积值低。
    • Sigmoid 激活: 点积计算出来的原始值可能很大或很小,通过 sigmoid 函数将其压缩到 [0, 1] 的范围内。这样,每个像素位置的输出值就表示该像素属于目标物体的概率。这就是最终的二值掩码预测 s s s,它的尺寸和原始图像(或某个尺度的特征图)相同( H × W H \times W H×W),每个位置的值代表属于目标的概率。
  5. 训练时的损失函数 L s e g \mathcal{L}_{seg} Lseg 既然现在预测的是像素掩码,训练时用来衡量好坏的标准也不同了。不再使用边界框损失,而是使用专门用于分割任务的损失函数。
    • 真实掩码 s g t s_{gt} sgt 这是数据集中为密集 Grounding 任务提供的真实标注,是一个像素级别的图,准确地标出了目标物体的每一个像素。
    • 掩码预测损失 L s e g \mathcal{L}_{seg} Lseg 包括两种常用的分割损失:
      • Focal Loss ( L f o c a l \mathcal{L}_{focal} Lfocal): 这种损失函数特别适合处理像素类别不平衡的问题(比如图像中背景像素远多于目标物体像素)。它可以让模型更关注那些难以区分的像素。
      • Dice Loss ( L d i c e \mathcal{L}_{dice} Ldice): 这种损失函数衡量预测掩码 s s s 和真实掩码 s g t s_{gt} sgt 之间的重叠程度。重叠得越多,损失越小。
    • 总的掩码预测损失 L s e g \mathcal{L}_{seg} Lseg 是这两种损失的加权求和 ( λ f o c a l , λ d i c e \lambda_{focal}, \lambda_{dice} λfocal,λdice 是权重)。训练时,模型会尝试最小化这个 L s e g \mathcal{L}_{seg} Lseg,以便预测的掩码 s s s 尽可能接近真实的 s g t s_{gt} sgt

扩展到密集 Grounding 是在原来模型的基础上,利用它已经学到的目标特征 O O O 和多模态融合特征 H m m H_{mm} Hmm增加一个专门的分割头。这个分割头通过 MLP 将 O O O 变成掩码嵌入 M M M,然后用 M M M H m m H_{mm} Hmm 计算点积来预测像素级别的掩码。同时,训练时改用或增加专门的分割损失来指导模型学习预测准确的像素掩码。本质上,它是复用了模型前面学习到的理解图像和语言的能力,并在这个基础上,通过一个额外的模块和相应的损失函数,将其应用于更精细的像素级别预测任务。

讨论

如图 5 所示,QA 模块中的注意力图展示了参照查询如何捕获目标相关上下文的细化过程。最初,注意力图看起来比较嘈杂,但逐渐聚焦于目标相关上下文,例如图 (a) 中的沙发。通过引入参照查询,解码器中的注意力图准确地集中在目标对象上。此外,需要注意的是,由于 QA 模块中的特征维度较低,参照查询可能不会精确地聚焦在目标对象上,但它仍然捕获了目标相关信息。

figure5

在这项工作中,我们的目标是探索如何进一步优化查询的学习过程。为了减少由普通查询带来的学习困难,我们引入了一个简单的查询自适应模块,以自适应地捕获目标相关上下文并迭代地对其进行优化精炼。如图 5 所示,每个查询自适应模块产生的注意力图与我们的目标一致:逐步聚焦于目标相关上下文,并为解码器提供先验上下文。值得注意的是,虽然“多级”、“适配器”和“自注意力”可能在其他研究领域得到广泛应用,但我们的方法旨在整合它们来解决视觉 Grounding 任务中的挑战,而不是设计一个特定的模块来单独实现上述功能。

实验部分

数据集和评估指标

RefCOCO/RefCOCO+/RefCOCOg: RefCOCO 包含 19,994 张图片,其中有 50,000 个参照对象,分为训练集、验证集、testA 集和 testB 集。类似地,RefCOCO+ 包含 19,992 张图片,其中有 49,856 个参照对象和 141,564 个参照表达。与 RefCOCO 相比,它包含更多属性而非绝对位置,并且具有相同的划分。RefCOCOg 包含 25,799 张图片,其中有 49,856 个参照对象和表达。遵循一种常见的划分版本 ,即训练集、验证集和测试集。

Flickr30k: Flickr30k Entities 包含 31,783 张图片和 158k 个带有 427k 标注短语的字幕。我们将图片分为 29,783 张用于训练,1000 张用于验证,1000 张用于测试,并在测试集上报告性能。

ReferItGame: ReferItGame 包含 20,000 张图片,其中有 120,072 个参照表达,对应 19,987 个参照对象。

评估指标: 对于参照表达理解 (REC),我们使用 Prec@0.5 评估协议来评估准确率,这与之前的工作一致。在这种评估中,如果预测的边界框与其地面真实边界框的 Intersection-over-Union (IoU) 大于 0.5,则认为该预测是正确的。对于参照表达分割 (RES),我们报告预测的分割掩码和地面真实掩码之间的 Mean IoU (MIoU)。

实现设置细节

根据 Transvg(2021年)和Dynamic mdetr(2023年),输入图像的分辨率被调整为 640 × 640。我们使用预训练的 CLIP 作为骨干网络来提取图像和语言特征,并在训练期间冻结其参数。模型使用 AdamW 优化器进行端到端优化,训练 40 个 epoch,批大小为 32。我们将学习率设置为 1e-4 ( 1 × 1 0 − 4 1 \times 10^{-4} 1×104),权重衰减设置为 1e-2。实验在 V100 GPU 上进行。损失权重 λ i o u \lambda_{iou} λiou λ L 1 \lambda_{L1} λL1 λ c e \lambda_{ce} λce λ a u x \lambda_{aux} λaux 分别设置为 3.0、1.0、1.0 和 0.1。对于密集 Grounding,我们将参数 λ f o c a l \lambda_{focal} λfocal λ d i c e \lambda_{dice} λdice 分别设置为 5.0 和 1.0。

主实验结果

table1
table2

消融实验结果

在 RefCOCOg 数据集上进行消融研究,以验证我们提出的方法中每个部分的有效性。
在这里插入图片描述

  • 多级融合层的影响: 在表 4 中,我们分析了融合层在解码器中的影响。我们首先进行了只使用单级图像特征的实验,然后进行多级特征的实验。结果表明,利用多级特征显著提高了性能,这表明低级和中级特征对高级特征形成了补充。此外,使用 {4, 8, 12} 实现了最佳性能,这也是我们在实验中采用的配置。

  • QA 位置的影响: 如表 5 所示,首先,我们可以观察到移除 QA 会导致性能急剧下降,这表明了 QA 的有效性。然后我们探索了 QA 在 CLIP 中不同位置的影响,以确定 QA 应该放置在哪里进行消融研究:{4, 8, 12} 和 {4, 6, 8, 10, 12},以及 {2, 4, 6, 8, 10, 12}。结果表明,当我们使用 {4, 6, 8, 10, 12} 配置时性能最佳。因此,我们在实验中默认使用此位置。

  • 辅助损失的影响: 在表 6 的第二行,我们通过有无辅助损失进行实验,结果证明了辅助损失的有效性。通过使用辅助损失,参照查询可以更有效地捕获目标相关的视觉上下文。

  • 可学习查询的影响: 在表 6 的第三行,我们验证了可学习查询(learnable query)的有效性。我们用 QA 模块生成的随机初始化查询或自然语言embedding替换了可学习查询,同时保持其他模块不变。我们可以观察到引入先验查询带来了显著的性能提升。这一结果表明,先验查询有助于解码器更准确地定位目标对象。此外,我们研究了参照查询的准确性,它们旨在为解码器提供先验信息。由于 QA 模块的通道维度较低,参照查询可能无法准确预测目标的坐标

在这里插入图片描述

  • 收敛曲线: 图 4 展示了我们提出的方法与开源 DETR 类视觉 Grounding 方法的收敛曲线。值得注意的是,我们的方法展示了加速的训练收敛速度,将训练时间缩短了一半,同时性能也优于其他现有方法。
  • RefFormer 方向的影响: 在 RefFormer 中,QA 模块可以作为适配器,将特定知识注入到冻结的 CLIP 模型中。在表 7 中,我们研究了 QA 模块的特征流方向。我们发现使用双向(bi-directional interaction)方法可以实现最佳性能。通过 QA 模块,语言特征逐步聚合相关的视觉上下文信息。正如 CARIS: Context-aware referring image segmentation 所指出的,将丰富的视觉上下文整合到语言特征中有助于实现强大的视觉-语言对齐,并更好地指示目标对象。
  • 可学习查询数量的影响: 我们在图 6 中展示了根据可学习查询数量 N q N_q Nq 的 Prec@0.5 性能。当我们采用 N q = 3 N_q = 3 Nq=3 时,性能最佳。然而,进一步增加只会使指标略有改善,因为大量的 N q N_q Nq 增加了模型的难度。因此,我们在实验中默认将 N q = 3 N_q = 3 Nq=3

在这里插入图片描述

可视化实验

如图 7 所示,参照查询逐渐聚焦于目标对象,并有效地为解码器提供了目标相关的上下文。这些结果证明了我们提出方法的有效性。
在这里插入图片描述

总结

总结与讨论: 本文提出了一种新颖的方法,称为 RefFormer,它可以无缝地集成到 CLIP 中。RefFormer 不仅可以生成参照查询,为解码器提供与目标相关的上下文,还可以作为适配器,保留 CLIP 的原始知识并降低训练成本。大量的实验证明了我们方法的有效性,可视化结果展示了我们提出的 RefFormer 的细化过程。

局限性: 尽管我们的方法是专门为 REC 任务设计的,并在 REC 中超越了现有的 SOTA(State-Of-The-Art,最新水平)方法,但在 RES 任务方面仍有很大的改进空间。这是因为我们尚未针对 RES 任务对我们的方法进行专门优化。

其他知识附录

附1:在预定义锚框上使用滑动窗口进行密集预测

在一张照片里找到所有的猫。你并不知道猫可能在哪里,也不知道猫是大是小,是躺着(长宽比大)还是蹲着(长宽比接近1:1)。传统两阶段目标检测框架通常包括候选区域生成阶段(Region Proposal Stage)和目标分类与定位阶段(Detection Stage)。在第一阶段,模型使用如选择性搜索(Selective Search)或区域建议网络(Region Proposal Network, RPN)等方法,从整幅图像中快速生成一组具有较高目标可能性的候选区域(Region Proposals)(即坐标、置信度分数(objectness score):表示该区域包含某种目标的可能性;锚框(anchor boxes)索引:表示该候选框是从哪个预设锚框偏移出来的),这些区域可能包含目标物体的轮廓或结构特征。在第二阶段,检测器对这些候选区域进行进一步的特征提取与分类,同时回归出更精确的目标边界框,从而实现目标的最终识别与定位。

单阶段 + 锚框 + 滑动窗口 + 密集预测”的方法:与两阶段方法不同,检测器只经过一次前向传播,直接输出最终结果(边界框 + 类别),无需候选区域生成。每个特定位置预设多个不同尺寸与长宽比的框,称为锚框(anchors)或默认框(default boxes),模型需要为每个锚框预测:

  • 是否包含目标(分类分支)
  • 如果有,目标类别是什么
  • 该锚框需要偏移多少才能拟合真实目标(回归分支)
  • 类似于每个锚框都问:“如果我这里有个物体,我该变成什么形状、属于哪一类?”
  1. 预定义锚框 (Pre-defined Anchor Boxes):

    • 不同大小和形状的矩形框(比如一个小的正方形模板,一个大的正方形模板,一个高的长方形模板,一个宽的长方形模板等等)。
    • 在开始找猫之前,你先把这些模板密密麻麻地铺满整张照片。你可以在照片上每隔一定的距离(比如每10个像素)就放一套这样的模板。所以整张照片上会有非常多非常多的、各种大小和形状的模板框。这些就是“预定义锚框”。 因为对象的大小和形状变化很大,提前准备好各种可能的尺寸和比例,是为了能“套住”各种不同的对象。
  2. 滑动窗口 (Sliding Window):

    • 虽然实现上是卷积网络全图处理,但本质上可以看作在特征图上以滑动窗口的方式遍历每个位置,并在这些位置做预测:
    • 每个位置相当于一个小区域(感受野)在“观察图像”
    • 每个区域对应多个锚框,每个锚框输出一组预测
  3. 密集预测 (Dense Predictions):

    • 当这个“滑动窗口”或者说网络的处理区域移动到图片上的某个位置时,它会同时检查覆盖在这个位置上的所有预定义锚框
    • 对于这个位置上的每一个预定义锚框,网络都会进行预测
      • 预测 1: 这个锚框里有没有我们要找的对象(比如猫)?如果有,它是猫的概率有多大?
      • 预测 2: 如果里面有猫,这个锚框的位置和大小需要调整多少,才能更精确地框住这只猫?
    • 因为这个检查和预测过程发生在图片上所有被锚框覆盖的位置,并且在每个位置都要检查所有预设的锚框,所以产生的预测结果数量非常庞大,几乎覆盖了图片上的每一个角落和各种可能的形状。这就是“密集预测”。典型的做法是在不同层级的特征图上进行预测(多尺度特征融合,如 FPN)。换句话说,相比于“精挑细选”候选框,两阶段方法的“少量精看”,单阶段是“眼睛到处看,每个地方都问一句”。
  • 做什么? 就是在图片上密密麻麻地放一堆各种大小形状的模板(锚框),然后网络系统地(通过滑动窗口的方式)检查图片上的每一个地方,对于每个地方的每一个模板,都预测一下“这里有没有猫,这个模板需要怎么改才能准确框住猫”。
  • 为什么要这么做?
    • 系统地搜索: 这种方法的好处在于它非常系统和全面。通过预设各种尺寸和比例的锚框并检查图片上的每一个位置,它几乎穷尽了对象可能出现的所有位置和大小的可能性。
    • 单阶段: 相比于需要先找候选区域的两阶段方法,这种方法将“找可能区域”和“判断是不是对象并精修”合并成了一个步骤,通常可以更快一些,流程也更简单。
特性两阶段(Faster R-CNN)单阶段(YOLO/SSD)
检测速度慢(多阶段)快(端到端)
检测精度高(尤其是小目标)相对略低
推理复杂度
工程部署相对复杂更轻量、适合实时场景

局限性:
这种方法有个缺点:它主要是在处理局部信息(看每个小窗口里的锚框),很难有效地理解对象之间的关系(比如猫和椅子是“坐”的关系,两只猫是“挨着”的关系),或者对象与大背景的联系。在视觉定位任务中,语言描述常常包含这些复杂的对象关系,如果方法不能理解这些关系,就很难准确地找到指定的对象,所以性能会受到影响,被称为“次优”。而像 DETR 那样不依赖锚框,使用全局注意力机制的方法,就更容易捕捉到这些复杂的关系。

附2:注意力池化(Attention Pooling)

注意力池化就像一个“智能”的筛选器,它会“阅读”序列中的所有向量,然后根据每个向量的内容判断它的重要性,最后把它们“混合”起来,但重要的向量在混合中占的比例更高。

注意力池化是一种加权平均的方法,用来将一个包含多个向量的序列(比如语言编码器输出的每一个词元的上下文嵌入向量序列)压缩成一个单一的固定长度向量。它的核心思想是:不是简单地平均所有向量,而是学习给序列中不同的向量分配不同的重要性权重,然后根据这些权重进行加权求和。

为什么需要它?

当你有一系列向量(比如,句子中每个词经过编码后的向量),你想用一个单一的向量来代表整个序列的意义。简单的做法可以是取平均值或最大值(这就是前面提到的平均池化和最大池化)。但这样做的缺点是,它们对序列中的所有元素一视同仁(平均池化)或者只关注最突出的元素(最大池化),无法灵活地根据上下文或任务需求来动态地决定哪些元素更重要。

注意力池化解决了这个问题,它让模型自己学习去“注意”序列中哪些部分更关键,从而在生成最终表示时给予这些关键部分更高的权重。

具体是如何做的?

假设我们有一个由 n n n 个向量组成的输入序列 H = [ h 1 , h 2 , . . . , h n ] H = [h_1, h_2, ..., h_n] H=[h1,h2,...,hn],其中 h i h_i hi 是序列中第 i i i 个元素的向量表示(比如第 i i i 个词的上下文嵌入)。注意力池化生成一个输出向量 O O O 的过程如下:

  1. 计算每个向量的“得分” (Scoring):

    • 对于序列中的每一个向量 h i h_i hi,通过一个小型的神经网络层(通常是一个全连接层)计算出一个标量值 s i s_i si。这个 s i s_i si 可以被看作是 h i h_i hi 对于最终表示的原始重要性得分
    • 这个计算过程可以简单表示为: s i = score ( h i ) s_i = \text{score}(h_i) si=score(hi) score \text{score} score 函数通常包含一些可学习的权重和偏置。一个常见的形式是先进行线性变换和非线性激活,再映射到单个得分: s i = v T tanh ( W h i + b ) s_i = v^T \text{tanh}(W h_i + b) si=vTtanh(Whi+b),其中 W , b , v W, b, v W,b,v 都是模型需要学习的参数。
  2. 将得分转换为“注意力权重” (Normalizing Scores into Attention Weights):

    • 原始得分 s 1 , s 2 , . . . , s n s_1, s_2, ..., s_n s1,s2,...,sn 的值范围是不定的,也不能直接作为权重。我们需要将它们转换成一组正数,并且这些正数加起来等于 1
    • 这通常通过 Softmax 函数来实现。Softmax 函数会将所有得分进行指数化,然后除以它们的总和,得到每个向量的注意力权重 α i \alpha_i αi
      α i = exp ⁡ ( s i ) ∑ j = 1 n exp ⁡ ( s j ) \alpha_i = \frac{\exp(s_i)}{\sum_{j=1}^{n} \exp(s_j)} αi=j=1nexp(sj)exp(si)
    • 这样得到的 α i \alpha_i αi 就是归一化后的注意力权重。 α i \alpha_i αi 值越大,说明对应的向量 h i h_i hi 越重要。
  3. 进行加权求和 (Weighted Sum):

    • 最后,用计算出的注意力权重 α i \alpha_i αi 对原始输入向量 h i h_i hi 进行加权求和,得到最终的输出向量 O O O
      O = ∑ i = 1 n α i h i O = \sum_{i=1}^{n} \alpha_i h_i O=i=1nαihi
    • 这个输出向量 O O O 就是通过注意力机制从整个序列中提取出的固定长度的表示,它倾向于包含那些被赋予更高注意力权重的向量的信息

优势:

  • 可学习性: 模型的神经网络会学习如何计算得分 s i s_i si,这意味着模型可以根据训练任务的需要,自动学习哪些输入向量(比如句子中的哪些词)是更重要的。
  • 动态性: 注意力权重是根据当前的输入序列动态计算的,同一个词在不同句子中的重要性可能不同,模型能够捕捉这种差异。
  • 更好的表示能力: 生成的固定长度向量 O O O 能够更精确地反映序列中最重要的信息,从而比简单的平均或最大池化有更强的表示能力。

附3:如何让一个原本只会给物体“画框框”的AI模型,升级为能够给物体“精细描绘轮廓”的模型

“对象级别的 Grounding” vs “密集 Grounding”

  • “对象级别的 Grounding”(Object-level Grounding)

    • 通俗理解:好比你告诉计算机:“图片里那只猫在哪里?” 计算机只需要画一个框框把猫圈出来就行。它告诉你猫的“大概位置和范围”。
    • 专业一点:这个任务是根据文本描述(比如“红色的球”)在图像中定位出对应的物体,并通常用一个边界框(bounding box)来标示该物体。
  • “密集 Grounding”(Dense Grounding)

    • 通俗理解:还是那个问题:“图片里那只猫在哪里?” 这次计算机不仅要找到猫,还要非常精细地把猫的每一个像素点都描绘出来,就像给猫做一个完美贴合的“轮廓剪影”。
    • 专业一点:这个任务要求更精细的定位,达到像素级别。它不仅仅是找到物体,还要精确地分割出物体的确切形状和边界。这就是 视觉分割(Visual Segmentation) 的核心。

这段话的核心思想:从“画框框”升级到“描轮廓”

作者说,他们的方法本来是做“对象级别 Grounding”(画框框)的,但通过一些改进,可以轻松地升级去做“密集 Grounding”(描轮廓/视觉分割)。

如何实现升级?——加入“分割头”(Segmentation Head)

  • “分割头”是什么?
    • 通俗理解:你可以把它想象成给原来只会“画框框”的AI模型加装了一个“精细描绘工具包”。这个工具包就是“分割头”。
    • 专业一点:在神经网络模型中,“头”(Head)指的是网络末端专门负责输出特定任务结果的部分。所以,“分割头”就是模型中专门设计用来输出像素级分割结果(即哪个像素属于哪个物体)的组件。

具体是怎么做的?——借鉴 MaskFormer 的思路

作者借鉴了一种叫做 MaskFormer 的先进分割模型的方法。步骤如下:

  1. 第一步:从“物体概念”到“轮廓蓝图”

    • 原文:“我们利用 MLP 将目标嵌入 O O O 转换为掩码嵌入 M ∈ R N q × D M \in \mathbb{R}^{N_q \times D} MRNq×D。”
    • 通俗理解
      • “目标嵌入 O O O”:这是模型在“画框框”阶段对物体理解的精华,可以想成是“这是一只猫”的抽象概念表示。
      • “MLP”(多层感知机):一种常见的神经网络模块,可以进行信息转换和提炼。
      • “掩码嵌入 M M M”:通过MLP这个转换器,把“这是一只猫”的抽象概念,转换成一种更具体的、准备用来画出猫轮廓的“蓝图”或“指令集”。
    • 专业一点:模型已经通过前面的步骤为每个可能的目标(由 N q N_q Nq 个查询表示)生成了目标嵌入 O O O。现在,一个MLP网络将这些目标嵌入 O O O(代表高级语义信息)转换为掩码嵌入 M M M。这些掩码嵌入 M M M 是专门为生成像素级掩码而优化的特征表示。每个查询会对应一个掩码嵌入。
  2. 第二步:用“轮廓蓝图”在“融合地图”上找到精确的像素点

    • 原文:“二值掩码预测 s i ∈ [ 0 , 1 ] H × W s_i \in [0, 1]^{H \times W} si[0,1]H×W 然后通过掩码嵌入 M M M 和多模态特征图 H m m H_{mm} Hmm 之间的点积计算得到,接着是一个 sigmoid 激活。”
    • 通俗理解
      • “多模态特征图 H m m H_{mm} Hmm”:这是一张融合了图像信息和文本信息(比如“那只在沙发上的猫”)的“详细地图”。它告诉我们图像中每个位置有什么特征,以及这些特征与文本描述的关联度。
      • “点积”和“sigmoid激活”:这是数学运算,你可以理解为拿着“轮廓蓝图 M M M”(代表要找的物体,比如猫的特定指令)去“详细地图 H m m H_{mm} Hmm”上逐个像素点比对。比对的结果会告诉我们,地图上的每个像素点有多大概率是属于我们要找的“猫”的轮廓的一部分。
      • “二值掩码预测 s i s_i si”:这就是最终的“描绘结果”,是一张跟原图一样大的图,图上每个点的值都在0到1之间。值越接近1,说明这个点越有可能是“猫”的一部分;越接近0,就越不可能是。如果把值大于某个阈值(比如0.5)的点标出来,就得到了猫的轮廓。
    • 专业一点
      • H m m H_{mm} Hmm 是一个包含了丰富空间信息和跨模态语义信息的特征图。
      • 对于每个查询(现在由其对应的掩码嵌入 M M M 代表),模型将其与 H m m H_{mm} Hmm 中的每个像素级特征进行点积运算。点积可以衡量两组向量的相似性或相关性。这里,它衡量的是每个像素的特征与该查询所代表的“物体掩码原型”的匹配程度。
      • Sigmoid激活函数将点积的结果压缩到 [ 0 , 1 ] [0, 1] [0,1] 范围内,解释为该像素属于该查询所对应物体的概率。这样就为每个查询生成了一个概率图 s i s_i si(大小为 H × W H \times W H×W,即图像的高和宽)。

如何训练模型学会“描轮廓”?——使用损失函数

  • 原文:“在训练过程中,我们使用掩码预测损失 (Focal loss 和 Dice loss),其定义如下: L s e g = λ f o c a l L f o c a l ( s g t , s ) + λ d i c e L d i c e ( s g t , s ) \mathcal{L}_{seg} = \lambda_{focal} \mathcal{L}_{focal}(s_{gt}, s) + \lambda_{dice} \mathcal{L}_{dice}(s_{gt}, s) Lseg=λfocalLfocal(sgt,s)+λdiceLdice(sgt,s),其中 s g t s_{gt} sgt 表示ground-truth掩码。”
  • 通俗理解
    • “ground-truth掩码 s g t s_{gt} sgt”:这是“标准答案”,即人工预先画好的、完美精确的猫的轮廓。
    • “损失函数”:这是一个“评分标准”。它比较计算机画的轮廓 s s s 和标准答案 s g t s_{gt} sgt 之间的差距。差距越大,得分越低(损失越大)。
    • “Focal loss 和 Dice loss”:这是两种比较高级的评分标准,特别适合用来评判“描轮廓”画得好不好。
    • 训练过程:计算机不断地尝试去画轮廓,然后用这个评分标准来打分。如果画得不好(损失大),计算机就会调整自己的“绘画技巧”(模型参数),争取下次画得更好。如此反复,直到它能画出非常接近标准答案的轮廓。
  • 专业一点
    • s g t s_{gt} sgt 是人工标注的真实分割掩码。
    • 损失函数 L s e g \mathcal{L}_{seg} Lseg 用来衡量模型预测的掩码 s s s 与真实掩码 s g t s_{gt} sgt 之间的差异。
    • Focal Loss 和 Dice Loss 是常用于分割任务的损失函数,它们能有效处理前景背景像素不平衡等问题,并关注分割区域的重叠度。
    • λ f o c a l \lambda_{focal} λfocal λ d i c e \lambda_{dice} λdice 是权重,用于平衡这两种损失的重要性。
    • 通过最小化这个损失函数,模型会学习调整其参数(包括MLP、掩码嵌入以及之前所有相关模块的参数),以便生成的预测掩码 s s s 尽可能地接近真实掩码 s g t s_{gt} sgt

核心步骤是:

  1. 加装“精细描绘工具包”(分割头)。
  2. 将对物体的抽象理解(目标嵌入 O O O)转换为 “轮廓蓝图” (掩码嵌入 M M M)。
  3. 拿着“轮廓蓝图 M M M”在包含图像和文本细节的 “融合地图” (多模态特征图 H m m H_{mm} Hmm)上进行比对,为每个像素点打出属于目标物体的概率,从而得到**“描绘结果”**(二值掩码预测 s s s)。
  4. 通过与**“标准答案”(真实掩码 s g t s_{gt} sgt)进行比较,并使用特定的“评分标准”**(Focal loss 和 Dice loss),不断优化模型的“绘画技巧”。

附4:两个 learnable regulation tokens 的作用

简单来说,这两个可学习的调节 token ( r v , r t \textbf{r}_v, \textbf{r}_t rv,rt) 就像两个智能的“阀门开关”和“信息采集员”。它们首先帮助从对方模态采集有用的上下文信息,形成调节信号;然后利用这些信号作为“阀门”,控制经过多模态融合和提炼后的特征的“流量”,确保最终注入到主干网络的信息既包含丰富的跨模态关联,又不会破坏各模态自身的宝贵特征。

两个learnable regulation tokens:作为可学习的“调节器”和“信息整合者”,对特征进行精炼和平衡。

  1. 初步信息处理

    • 图像经过 CLIP 的图像编码器得到多层级的视觉特征。
    • 文本(指代表达式)经过 CLIP 的文本编码器得到文本特征 H t \textbf{H}_t Ht
    • 一组可学习的查询 Q 0 \textbf{Q}^0 Q0 (Queries) 被随机初始化。
  2. 查询自适应模块 (QA Module) - 内部循环
    QA 模块被多次(例如3次,对应CLIP的第4, 6, 9层)插入到 CLIP 的不同层级之间。在每个 QA 模块中:

    • CAMF 模块 (Condition Aggregation and Multi-modal Fusion)
      • 上一轮的查询 Q i − 1 \textbf{Q}^{i-1} Qi1 与当前的视觉特征 F v i \textbf{F}_v^i Fvi(来自CLIP特定层)和文本特征 F t i \textbf{F}_t^i Fti(可能是全局文本特征或其变体)进行交互。
      • 这里的交互主要是通过一个多头跨注意力层(MHCA)来实现,如公式 (7): r ˉ v , Q ~ c i , F ~ v i = M H C A ( [ r v ; Q i − 1 ; F v i ] , F t i , F t i ) \bar{\textbf{r}}_v, \tilde{\textbf{Q}}_c^i, \tilde{\textbf{F}}_v^i = MHCA([\textbf{r}_v; \textbf{Q}^{i-1}; \textbf{F}_v^i], \textbf{F}_t^i, \textbf{F}_t^i) rˉv,Q~ci,F~vi=MHCA([rv;Qi1;Fvi],Fti,Fti)
        • 注意这里的变化:根据公式 (7), r v \textbf{r}_v rv Q i − 1 \textbf{Q}^{i-1} Qi1 以及 F v i \textbf{F}_v^i Fvi 被拼接起来作为 MHCA 的查询 (Query),而文本特征 F t i \textbf{F}_t^i Fti 同时作为键 (Key) 和值 (Value)。这意味着 r v \textbf{r}_v rv 在这里参与了“条件聚合”,即帮助从文本特征中提取与视觉相关的条件信息。
      • 类似地,公式 (9) 中: r ˉ t , F ~ t i = M H C A ( [ r t ; F t i ] , F v i , F v i ) \bar{\textbf{r}}_t, \tilde{\textbf{F}}_t^i = MHCA([\textbf{r}_t; \textbf{F}_t^i], \textbf{F}_v^i, \textbf{F}_v^i) rˉt,F~ti=MHCA([rt;Fti],Fvi,Fvi) (这里我根据对称性推测,原文公式(7)和(9)对 r v \textbf{r}_v rv r t \textbf{r}_t rt的处理相似, r t \textbf{r}_t rt F t i \textbf{F}_t^i Fti 拼接作为查询,去“查询”视觉特征 F v i \textbf{F}_v^i Fvi)
        • 修正理解:根据公式 (7) 和 (9), r v \textbf{r}_v rv r t \textbf{r}_t rt作为 query 的一部分,与各自模态的特征 ( Q i − 1 , F v i \textbf{Q}^{i-1}, \textbf{F}_v^i Qi1,Fvi 对于 r v \textbf{r}_v rv F t i \textbf{F}_t^i Fti 对于 r t \textbf{r}_t rt) 一起,去查询另一种模态的特征
          • 公式 (7): [ r v ; Q i − 1 ; F v i ] [\textbf{r}_v; \textbf{Q}^{i-1}; \textbf{F}_v^i] [rv;Qi1;Fvi] 作为查询,去查询文本特征 F t i \textbf{F}_t^i Fti。输出 r ˉ v \bar{\textbf{r}}_v rˉv r v \textbf{r}_v rv 对应的输出, Q ~ c i \tilde{\textbf{Q}}_c^i Q~ci Q i − 1 \textbf{Q}^{i-1} Qi1 对应的输出, F ~ v i \tilde{\textbf{F}}_v^i F~vi F v i \textbf{F}_v^i Fvi 对应的输出。
          • 公式 (9): [ r t ; F t i ] [\textbf{r}_t; \textbf{F}_t^i] [rt;Fti] 作为查询,去查询视觉特征 F v i \textbf{F}_v^i Fvi。输出 r ˉ t \bar{\textbf{r}}_t rˉt r t \textbf{r}_t rt 对应的输出, F ~ t i \tilde{\textbf{F}}_t^i F~ti F t i \textbf{F}_t^i Fti 对应的输出。
      • 然后这些输出会经过层归一化和残差连接(公式 (8))。
    • TR 模块 (Target-related Context Refinement)
      • CAMF 模块输出的查询 Q ~ c i \tilde{\textbf{Q}}_c^i Q~ci 和多模态增强特征 F ~ v i , F ~ t i \tilde{\textbf{F}}_v^i, \tilde{\textbf{F}}_t^i F~vi,F~ti 会进入 TR 模块。
      • 查询 Q ~ c i \tilde{\textbf{Q}}_c^i Q~ci 会与增强后的视觉特征 F ~ v i \tilde{\textbf{F}}_v^i F~vi 交互,以优化查询(公式 (10))。
      • 增强后的视觉特征 F ~ v i \tilde{\textbf{F}}_v^i F~vi 和文本特征 F ~ t i \tilde{\textbf{F}}_t^i F~ti 会进行自注意力(或与彼此交互),进一步增强其上下文表示,得到 G v i G_v^i Gvi G t i G_t^i Gti(公式 (11), (12))。
  3. 特征调节与上投影 (Up-projection) - 关键作用点

    • 根据公式 (13):“…we utilize the regulation token to modulate the features G v G_v Gv and G t G_t Gt, which helps prevent the multi-modal signal from overpowering the original signal.” (我们利用调节 token 来调节特征 G v G_v Gv G t G_t Gt,这有助于防止多模态信号压倒原始信号。)
    • 公式 (13):
      Z ^ v i = ϕ v u i ( G v i × σ ( r ˉ v ) ) + Z v i \hat{Z}_v^i = \phi_{vu}^i (G_v^i \times \sigma(\bar{\textbf{r}}_v)) + Z_v^i Z^vi=ϕvui(Gvi×σ(rˉv))+Zvi
      Z ^ t i = ϕ t u i ( G t i × σ ( r ˉ t ) ) + Z t i \hat{Z}_t^i = \phi_{tu}^i (G_t^i \times \sigma(\bar{\textbf{r}}_t)) + Z_t^i Z^ti=ϕtui(Gti×σ(rˉt))+Zti
      • 这里的 r ˉ v \bar{\textbf{r}}_v rˉv r ˉ t \bar{\textbf{r}}_t rˉt 是从 CAMF 模块的 MHCA 中,由原始调节 token r v \textbf{r}_v rv r t \textbf{r}_t rt 作为查询的一部分与另一模态交互后得到的输出。它们可以看作是“调节因子”或“门控信号”。
      • σ ( ⋅ ) \sigma(\cdot) σ() 是 sigmoid 函数,将调节因子的值缩放到 (0, 1) 之间,使其可以作为一种“权重”或“门”。
      • G v i × σ ( r ˉ v ) G_v^i \times \sigma(\bar{\textbf{r}}_v) Gvi×σ(rˉv):表示将 TR 模块输出的视觉特征 G v i G_v^i Gvi 与这个视觉调节门 σ ( r ˉ v ) \sigma(\bar{\textbf{r}}_v) σ(rˉv) 进行元素相乘 (element-wise multiplication)。这相当于对 G v i G_v^i Gvi 的每个通道或维度进行加权,调节了 G v i G_v^i Gvi 的强度或重要性。文本侧同理。
      • ϕ v u i ( ⋅ ) \phi_{vu}^i(\cdot) ϕvui() ϕ t u i ( ⋅ ) \phi_{tu}^i(\cdot) ϕtui() 是 MLP 层,用于恢复通道维度。
      • 最后通过残差连接加回到原始的 CLIP 主干特征 Z v i Z_v^i Zvi Z t i Z_t^i Zti 上。

作用总结:

  1. 在 CAMF 模块中作为“探测器”/“引导器”

    • r v \textbf{r}_v rv(与 Q i − 1 , F v i \textbf{Q}^{i-1}, \textbf{F}_v^i Qi1,Fvi 一起)去“探测”文本特征 F t i \textbf{F}_t^i Fti,以获取文本中与当前视觉信息和查询相关的上下文,生成调节因子 r ˉ v \bar{\textbf{r}}_v rˉv
    • r t \textbf{r}_t rt(与 F t i \textbf{F}_t^i Fti 一起)去“探测”视觉特征 F v i \textbf{F}_v^i Fvi,以获取图像中与当前文本信息相关的上下文,生成调节因子 r ˉ t \bar{\textbf{r}}_t rˉt
      这些调节因子 r ˉ v , r ˉ t \bar{\textbf{r}}_v, \bar{\textbf{r}}_t rˉv,rˉt 本身就携带了跨模态交互后的信息。
  2. 在最终特征输出前作为“平衡器”/“门控调节器” (公式 13):

    • r ˉ v \bar{\textbf{r}}_v rˉv r ˉ t \bar{\textbf{r}}_t rˉt 经过 sigmoid 函数后,变成门控信号,分别与 TR 模块处理后的视觉特征 G v i G_v^i Gvi 和文本特征 G t i G_t^i Gti 相乘。
    • 核心目的:调节 G v i G_v^i Gvi G t i G_t^i Gti 的强度。正如文中所述,这有助于**“防止多模态信号压倒原始信号”**。意思是,在多模态交互过程中,新融合进来的信息(来自另一模态的信息)可能会过于强烈,导致原模态自身的重要信息被削弱或忽略。这两个调节 token 通过学习到的门控值,可以适当地缩放这些融合后的特征,使得与原始主干特征 Z v i , Z t i Z_v^i, Z_t^i Zvi,Zti 的残差连接更加平衡和有效。

为什么要引入?

  1. 更精细的跨模态交互控制:不仅仅是简单地融合特征,而是通过可学习的 token 来引导和调节这个融合过程。
  2. 防止信息丢失或压制:在深度网络和复杂的特征融合中,保持原始信息和新信息的平衡非常重要。这些调节 token 提供了一种机制来动态调整这种平衡,防止某一模态的信息在融合后被另一模态完全主导。
  3. 提升特征的表达能力和任务适应性:通过学习,这些调节 token 可以帮助模型生成对当前特定任务(如指代表达理解)更有判别力的特征表示。

如何初始化?

虽然论文中没有明确说明 r v \textbf{r}_v rv r t \textbf{r}_t rt 的初始化,但它们作为“可学习的 (learnable)” token,最常见的初始化方式是随机初始化(例如,从均值为0、标准差较小的高斯分布中采样)。文中提到可学习查询 Q \textbf{Q} Q 是随机初始化的 (“We randomly initialize N q N_q Nq learnable queries Q \textbf{Q} Q…”),因此这两个调节 token 很可能也采用类似的随机初始化策略。这样它们可以在训练过程中从数据中学习到如何扮演好各自的“探测器”和“平衡器”角色。

附5:为什么要引入Q’

引入 Q ′ \textbf{Q}' Q 这个东西主要是为了在解码阶段更好地利用先前在查询自适应模块 (QA module) 中学习到的“指代性查询” (referential query) Q \textbf{Q} Q 的先验知识,同时又为解码器提供一组“干净”的、可塑性强的查询来进行最终的目标定位。

  1. 什么是 Q \textbf{Q} Q

    • 在前面的 QA 模块中,一组初始查询 Q 0 \textbf{Q}^0 Q0 经过与多层级的图像特征和文本特征的迭代交互,最终会得到一组“指代性查询” Q \textbf{Q} Q
    • 这个 Q \textbf{Q} Q 已经学习到了大量关于文本描述所指代的特定目标应该具有什么样的特征和上下文信息。它可以被看作是解码器进行目标定位的一个非常有价值的“先验知识”或“初始猜测”。
  2. 为什么需要 Q ′ \textbf{Q}' Q

    • 利用先验知识:直接丢弃 Q \textbf{Q} Q 会浪费掉 QA 模块学习到的宝贵信息。所以,解码器需要一种方式来利用 Q \textbf{Q} Q
    • 避免干扰,提供可塑性:虽然 Q \textbf{Q} Q 包含了有用的先验,但它可能也“固化”了一些信息。如果直接用 Q \textbf{Q} Q 作为解码器的唯一查询,可能会限制解码器在最终融合特征图 H m m \textbf{H}_{mm} Hmm 上进行精细化搜索和定位的灵活性。解码器可能需要一组更“中性”或“空白”的查询,以便在新的、融合后的多模态特征上从头开始学习如何精确地提取目标嵌入 O \textbf{O} O
    • 控制信息注入方式:通过引入 Q ′ \textbf{Q}' Q,模型可以更灵活地控制如何将 Q \textbf{Q} Q 的先验知识与新的查询 Q ′ \textbf{Q}' Q 结合。
  3. Q ′ \textbf{Q}' Q 是如何工作的?

    • 初始化 Q ′ \textbf{Q}' Q

      • “Following, we first initialize the queries Q ′ \textbf{Q}' Q with the same size as the referential query Q \textbf{Q} Q…” ( Q ′ \textbf{Q}' Q 的尺寸和 Q \textbf{Q} Q 一样,确保它们可以相加。)
      • “…and add them together to utilize the prior context in Q \textbf{Q} Q.” (将 Q \textbf{Q} Q Q ′ \textbf{Q}' Q 相加,以此方式注入 Q \textbf{Q} Q 的先验知识。)
      • 关键点:“Note that, to avoid interference from Q ′ \textbf{Q}' Q during the initial stage, we initialize Q ′ \textbf{Q}' Q as an all-zero matrix.” (为了避免 Q ′ \textbf{Q}' Q 在初始阶段对 Q \textbf{Q} Q 的先验造成干扰, Q ′ \textbf{Q}' Q 被初始化为全零矩阵。)
        • 这意味着,在第一次跨注意力操作(公式18)之前, ϕ q ( Q ) + Q ′ \phi_q(\textbf{Q}) + \textbf{Q}' ϕq(Q)+Q 实际上等于 ϕ q ( Q ) \phi_q(\textbf{Q}) ϕq(Q)(因为 Q ′ \textbf{Q}' Q 是零)。 ϕ q ( ⋅ ) \phi_q(\cdot) ϕq() 是一个 MLP 层,用于调节 Q \textbf{Q} Q 的重要性。所以,初始的查询主要是由 Q \textbf{Q} Q 引导的。
    • 与多模态特征交互

      • 这些组合后的查询 ( ϕ q ( Q ) + Q ′ \phi_q(\textbf{Q}) + \textbf{Q}' ϕq(Q)+Q) 与视觉-语言融合特征图 H v m l , \textbf{H}_{vml}, Hvml以及文本特征 H t \textbf{H}_t Ht 一起输入到多头跨注意力层 (MHCA),以聚合条件信息并产生多模态特征图 H m m \textbf{H}_{mm} Hmm 和初步的目标相关查询 O ˉ c \bar{\textbf{O}}_c Oˉc。(公式 18, 19)
        • 在这里, Q ′ \textbf{Q}' Q 虽然初始为零,但它是可学习的。在训练过程中, Q ′ \textbf{Q}' Q 会逐渐学习到如何调整,以帮助解码器更好地从 H v m l \textbf{H}_{vml} Hvml H t \textbf{H}_t Ht 中提取信息。
    • 提取最终目标嵌入

      • 之后,经过层归一化和残差连接得到的查询 O c \textbf{O}_c Oc 会再次作为查询,去查询多模态特征图 H m m \textbf{H}_{mm} Hmm,以提取最终的目标嵌入 O \textbf{O} O。(公式 20, 21)
      • 在这个阶段, O c \textbf{O}_c Oc 已经融合了 Q \textbf{Q} Q 的先验和通过 Q ′ \textbf{Q}' Q 学习到的适应性调整。

总结来说,引入 Q ′ \textbf{Q}' Q 的原因和作用是:

  1. 提供可学习的“画布” Q ′ \textbf{Q}' Q 最初是“空白的”(全零),它提供了一组可学习的参数,让解码器有能力在 Q \textbf{Q} Q 提供的先验基础上进行微调和适应。它不像 Q \textbf{Q} Q 那样已经承载了大量来自 QA 模块的特定信息,因此具有更大的可塑性。
  2. 受控地利用先验知识:通过将 Q ′ \textbf{Q}' Q 与(可能经过 MLP 调节的) Q \textbf{Q} Q 相加,模型既利用了 Q \textbf{Q} Q 中学到的上下文,又允许 Q ′ \textbf{Q}' Q 在训练中学习如何最好地补充或调整这些先验,以适应解码阶段的特定需求。
  3. 避免初始干扰:将 Q ′ \textbf{Q}' Q 初始化为全零,确保在解码过程的早期,查询主要由经过 QA 模块提炼的 Q \textbf{Q} Q 主导,避免了随机初始化的 Q ′ \textbf{Q}' Q 可能带来的噪声或干扰。随着训练的进行, Q ′ \textbf{Q}' Q 会逐渐学习到有意义的值。
  4. 解耦先验和适应性:可以看作是将“先验指导”(来自 Q \textbf{Q} Q)和“当前阶段的适应性学习”(通过 Q ′ \textbf{Q}' Q)在一定程度上解耦,使得解码器既能站在“巨人”的肩膀上,又能灵活地“跳舞”。

因此, Q ′ \textbf{Q}' Q 充当了一个可学习的、初始中性的“适配器”或“增量调整器”,它与经验丰富的“老兵” Q \textbf{Q} Q 合作,共同引导解码器在融合后的多模态特征中精确地定位目标。

附6:伪代码

这份伪代码力求在结构上和关键步骤上贴合论文描述的核心方法。实际完整的实现会涉及更细致的维度匹配、模块参数设置、损失函数具体实现(包括二分图匹配)等。

import torch
import torch.nn as nn
import torch.nn.functional as F

# 假设的CLIP模型 (冻结)
class FrozenCLIP(nn.Module):
    def __init__(self, visual_layers_to_return=(3, 5, 7, 9, 11), text_layers_to_return=(-1,)): # 示例层级
        super().__init__()
        # 此处应加载预训练的CLIP模型
        # self.visual_encoder = clip.load_visual_encoder(...)
        # self.text_encoder = clip.load_text_encoder(...)
        self.visual_layers_to_return = visual_layers_to_return # QA模块插入及特征提取层
        self.text_layers_to_return = text_layers_to_return # 文本特征提取层

        # 假设CLIP有L层视觉Transformer层
        self.visual_transformer_layers = nn.ModuleList([nn.Identity() for _ in range(12)]) # 简化表示
        self.text_transformer_layers = nn.ModuleList([nn.Identity() for _ in range(12)])   # 简化表示

        # 冻结参数
        for param in self.parameters():
            param.requires_grad = False

    def forward_visual_ μέχρι_layer(self, x, target_layer_idx):
        # 模拟CLIP视觉编码器逐层处理,并返回指定层的输出
        # ... CLIP 图像预处理和patch embedding ...
        # x_cls, x_patch_embed = ...
        # x = x_cls + x_patch_embed + pos_embed
        outputs = {}
        for i, layer in enumerate(self.visual_transformer_layers):
            x = layer(x)
            if i == target_layer_idx:
                outputs[i] = x
                return x, outputs # 简化:只返回目标层和当前所有捕获的输出
        return x, outputs # 返回最后一层如果target_layer_idx超出

    def forward_text_μέχρι_layer(self, text_tokens, target_layer_idx):
        # 模拟CLIP文本编码器逐层处理
        # ... CLIP 文本预处理和token embedding ...
        # x = token_embed + pos_embed
        outputs = {}
        for i, layer in enumerate(self.text_transformer_layers):
            x = layer(x)
            if i == target_layer_idx:
                outputs[i] = x
                return x, outputs
        return x, outputs

    def forward(self, image, text_tokens):
        # 为了简化,这里不完全模拟逐层交互,而是假设能获取各层特征
        # 实际RefFormer会在CLIP层之间插入QA模块
        
        # 提取文本特征 (通常取最后一层)
        # text_features_full, _ = self.forward_text_μέχρι_layer(text_tokens, self.text_transformer_layers.__len__() -1)
        # H_t = text_features_full
        # 假设直接获得最终文本特征
        H_t = torch.randn(text_tokens.size(0), text_tokens.size(1), 512) # B, L_text, D_clip

        # 提取视觉特征 (多层)
        # image_features_multilevel = {}
        # current_image_feat = image # 初始图像输入
        # for i in range(len(self.visual_transformer_layers)):
        #     current_image_feat, _ = self.forward_visual_μέχρι_layer(current_image_feat, i)
        #     if i in self.visual_layers_to_return:
        #         image_features_multilevel[i] = current_image_feat
        # 假设直接获得指定层级的视觉特征
        image_features_multilevel = {
            idx: torch.randn(image.size(0), 197, 768) for idx in self.visual_layers_to_return
        } # B, N_patches+1, D_clip_vision
        
        return image_features_multilevel, H_t


# 基础模块
class MHCA(nn.Module): # Multi-Head Cross-Attention
    def __init__(self, query_dim, kv_dim, embed_dim, num_heads):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, kdim=kv_dim, vdim=kv_dim, batch_first=True)
        self.q_proj = nn.Linear(query_dim, embed_dim) if query_dim != embed_dim else nn.Identity()
        # kv_dim is already embed_dim for key/value if coming from same source or projected outside

    def forward(self, query, key, value):
        query = self.q_proj(query)
        # key, value 已经是所需的维度
        attn_output, _ = self.multihead_attn(query, key, value)
        return attn_output

class MHSA(nn.Module): # Multi-Head Self-Attention
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, x):
        attn_output, _ = self.multihead_attn(x, x, x)
        return attn_output

class FFN(nn.Module): # FeedForward Network (MLP)
    def __init__(self, embed_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(embed_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(hidden_dim, embed_dim)

    def forward(self, x):
        return self.linear2(self.dropout(F.relu(self.linear1(x))))

# 查询自适应模块 (QA Module)
class QueryAdaptationModule(nn.Module):
    def __init__(self, clip_dim_v, clip_dim_t, query_dim, proj_dim_l, num_heads=8):
        super().__init__()
        self.proj_dim_l = proj_dim_l

        # 降维投影
        self.visual_down_proj = nn.Linear(clip_dim_v, proj_dim_l)
        self.text_down_proj = nn.Linear(clip_dim_t, proj_dim_l)

        # 可学习的调节token
        self.reg_token_v = nn.Parameter(torch.randn(1, 1, proj_dim_l))
        self.reg_token_t = nn.Parameter(torch.randn(1, 1, proj_dim_l))

        # CAMF (Condition Aggregation and Multi-modal Fusion)
        self.camf_cross_attn_vq_t = MHCA(proj_dim_l, proj_dim_l, proj_dim_l, num_heads) # (Query_vis, Queries_obj), Key_text, Value_text
        self.camf_cross_attn_t_v = MHCA(proj_dim_l, proj_dim_l, proj_dim_l, num_heads)  # Query_text, Key_vis, Value_vis
        self.camf_ln_vq = nn.LayerNorm(proj_dim_l)
        self.camf_ln_t = nn.LayerNorm(proj_dim_l)

        # TR (Target-related Context Refinement)
        self.tr_cross_attn_q_v = MHCA(proj_dim_l, proj_dim_l, proj_dim_l, num_heads) # Query_obj, Key_vis_enhanced, Value_vis_enhanced
        self.tr_ffn_q = FFN(proj_dim_l, proj_dim_l * 4)
        self.tr_ln_q = nn.LayerNorm(proj_dim_l)

        self.tr_self_attn_v = MHSA(proj_dim_l, num_heads)
        self.tr_ffn_v = FFN(proj_dim_l, proj_dim_l * 4)
        self.tr_ln_v = nn.LayerNorm(proj_dim_l)

        self.tr_self_attn_t = MHSA(proj_dim_l, num_heads)
        self.tr_ffn_t = FFN(proj_dim_l, proj_dim_l * 4)
        self.tr_ln_t = nn.LayerNorm(proj_dim_l)

        # 升维投影 (用于残差注入回CLIP主干)
        self.visual_up_proj = nn.Linear(proj_dim_l, clip_dim_v)
        self.text_up_proj = nn.Linear(proj_dim_l, clip_dim_t)
        
        # 用于辅助损失的输出头 (可选)
        self.aux_grounding_head = GroundingHead(proj_dim_l, num_classes=1, num_queries=3) # 假设 num_queries 固定

    def forward(self, Z_v_i, Z_t_i, Q_prev):
        """
        Z_v_i: 来自CLIP第i层视觉特征 (B, N_v, D_clip_v)
        Z_t_i: 来自CLIP第i层文本特征 (B, N_t, D_clip_t) (实践中可能用固定的最终文本特征)
        Q_prev: 上一个QA模块输出的查询 (B, N_q, D_l) 或初始查询
        """
        B, N_q, _ = Q_prev.shape
        
        # 1. 降维
        F_v_i = self.visual_down_proj(Z_v_i) # (B, N_v, D_l)
        F_t_i = self.text_down_proj(Z_t_i) # (B, N_t, D_l)

        # 准备调节token
        reg_v = self.reg_token_v.expand(B, -1, -1)
        reg_t = self.reg_token_t.expand(B, -1, -1)

        # 2. CAMF
        # Eq 7, 8
        camf_q_in_v = torch.cat([reg_v, Q_prev, F_v_i], dim=1)
        vq_t_fused = self.camf_cross_attn_vq_t(camf_q_in_v, F_t_i, F_t_i)
        
        # 残差和LN
        _reg_v_fused, _Q_c_i, _F_v_i_fused_by_t = torch.split(vq_t_fused, [1, N_q, F_v_i.size(1)], dim=1)
        Q_c_i = self.camf_ln_vq(_Q_c_i + Q_prev) # Query融合了文本信息
        F_v_i_fused_by_t = self.camf_ln_vq(_F_v_i_fused_by_t + F_v_i) # 视觉特征融合了文本信息
        reg_v_fused = _reg_v_fused # 更新后的调节token部分

        # Eq 9
        camf_q_in_t = torch.cat([reg_t, F_t_i], dim=1)
        t_v_fused = self.camf_cross_attn_t_v(camf_q_in_t, F_v_i, F_v_i) # 注意这里用原始 F_v_i
        _reg_t_fused, _F_t_i_fused_by_v = torch.split(t_v_fused, [1, F_t_i.size(1)], dim=1)
        F_t_i_fused_by_v = self.camf_ln_t(_F_t_i_fused_by_v + F_t_i) # 文本特征融合了视觉信息
        reg_t_fused = _reg_t_fused

        # 3. TR
        # Eq 10 ( refine query Q_c_i with F_v_i_fused_by_t )
        Q_r_i = self.tr_cross_attn_q_v(Q_c_i, F_v_i_fused_by_t, F_v_i_fused_by_t)
        Q_next = self.tr_ln_q(self.tr_ffn_q(Q_r_i) + Q_c_i) # 输出的参考性查询

        # Eq 11 (enhance visual features)
        v_in_tr = torch.cat([reg_v_fused, F_v_i_fused_by_t], dim=1)
        v_self_attended = self.tr_self_attn_v(v_in_tr)
        v_enhanced_tr = self.tr_ln_v(self.tr_ffn_v(v_self_attended) + v_in_tr)
        G_v_reg, G_v_i = torch.split(v_enhanced_tr, [1, F_v_i.size(1)], dim=1)


        # Eq 12 (enhance text features)
        t_in_tr = torch.cat([reg_t_fused, F_t_i_fused_by_v], dim=1)
        t_self_attended = self.tr_self_attn_t(t_in_tr)
        t_enhanced_tr = self.tr_ln_t(self.tr_ffn_t(t_self_attended) + t_in_tr)
        G_t_reg, G_t_i = torch.split(t_enhanced_tr, [1, F_t_i.size(1)], dim=1)

        # 4. 升维与特征调制 (用于注入回CLIP)
        # Eq 13
        Z_v_i_hat = self.visual_up_proj(G_v_i * torch.sigmoid(G_v_reg)) # 特征调制
        Z_t_i_hat = self.text_up_proj(G_t_i * torch.sigmoid(G_t_reg))
        
        # QA模块的辅助输出 (用于辅助损失)
        aux_box_preds, aux_cls_preds = self.aux_grounding_head(Q_next, G_v_i) # 用QA模块内的视觉特征

        return Z_v_i_hat, Z_t_i_hat, Q_next, (aux_box_preds, aux_cls_preds)


# 任务特定解码器 (简化版 DETR Decoder)
class RefFormerDecoder(nn.Module):
    def __init__(self, num_layers, embed_dim, num_heads, num_queries):
        super().__init__()
        self.layers = nn.ModuleList([
            DecoderLayer(embed_dim, num_heads) for _ in range(num_layers)
        ])
        self.num_queries = num_queries
        # 解码器自身的查询 (论文中提到初始化为0,与参考性查询相加)
        self.decoder_queries_init = nn.Parameter(torch.zeros(1, num_queries, embed_dim))
        self.query_mlp = nn.Linear(embed_dim, embed_dim) # 用于调节参考性查询Q的重要性 (phi_q)

    def forward(self, referential_queries_Q, lang_aware_multi_level_vis_feat, H_t):
        """
        referential_queries_Q: 来自最后一个QA模块的参考性查询 (B, N_q, D_l)
        lang_aware_multi_level_vis_feat: 融合了语言信息的多层级视觉特征 (B, N_total_v, D_l)
        H_t: 文本特征 (B, N_t, D_l)
        """
        B = referential_queries_Q.size(0)
        # Eq 18: 初始化解码器查询并与参考性查询结合
        # O_c in paper; Q_dec here
        Q_dec = self.query_mlp(referential_queries_Q) + self.decoder_queries_init.expand(B, -1, -1)
        
        target_embeddings = Q_dec
        for layer in self.layers:
            target_embeddings = layer(target_embeddings, lang_aware_multi_level_vis_feat, H_t)
        
        return target_embeddings # (B, N_q, D_l)

class DecoderLayer(nn.Module): # 单层解码器
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        # self.self_attn = MHSA(embed_dim, num_heads) # DETR style self-attn on queries
        # self.ln1 = nn.LayerNorm(embed_dim)
        
        # Eq 18, 19: concatenate queries with image features to interact with language
        # 论文的解码器结构与标准DETR稍有不同,直接将查询与图像特征拼接后和文本特征做交叉注意力
        # 然后再用更新后的查询和多模态特征图做交叉注意力
        self.cross_attn_q_img_text = MHCA(embed_dim, embed_dim, embed_dim, num_heads) # Q=[Q_dec; ImgFeat], KV=TextFeat
        self.ln_q_img_text_q = nn.LayerNorm(embed_dim)
        self.ln_q_img_text_img = nn.LayerNorm(embed_dim)

        # Eq 20, 21
        self.cross_attn_q_multimodal = MHCA(embed_dim, embed_dim, embed_dim, num_heads) # Q=O_c, KV=H_mm
        self.ffn = FFN(embed_dim, embed_dim * 4)
        self.ln_final_q = nn.LayerNorm(embed_dim)


    def forward(self, Q_dec, vis_feat, text_feat):
        # 简化实现论文中的解码逻辑
        # Step 1: (Eq 18, 19) Queries O_c and multi-modal feature H_mm generation
        num_queries = Q_dec.size(1)
        q_img_concat = torch.cat([Q_dec, vis_feat], dim=1)
        
        # 假设文本特征 H_t 是key和value
        fused_q_img = self.cross_attn_q_img_text(q_img_concat, text_feat, text_feat)
        O_c_bar, H_mm_bar = torch.split(fused_q_img, [num_queries, vis_feat.size(1)], dim=1)
        
        O_c = self.ln_q_img_text_q(O_c_bar + Q_dec) # Residual for queries
        H_mm = self.ln_q_img_text_img(H_mm_bar + vis_feat) # Residual for image features (now multimodal)

        # Step 2: (Eq 20, 21) Extract target embeddings O_hat
        O_bar = self.cross_attn_q_multimodal(O_c, H_mm, H_mm)
        # 论文中 O = LN(phi_r(O_bar)) + O_bar; phi_r是线性层。这里简化为FFN
        O_hat = self.ln_final_q(self.ffn(O_bar) + O_bar) # 论文用的是 MLP(LN(O_bar)) + O_bar
        
        return O_hat


# 定位头
class GroundingHead(nn.Module):
    def __init__(self, embed_dim, num_classes=1, num_queries=3): # num_classes=1 for confidence
        super().__init__()
        self.box_mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.ReLU(),
            nn.Linear(embed_dim, embed_dim), nn.ReLU(),
            nn.Linear(embed_dim, 4) # x,y,h,w
        )
        self.cls_mlp = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.ReLU(),
            nn.Linear(embed_dim, num_classes) # confidence score
        )
        self.num_queries = num_queries

    def forward(self, target_embeddings, visual_features_for_mask=None, produce_mask=False):
        # target_embeddings: (B, N_q, D)
        pred_boxes = self.box_mlp(target_embeddings).sigmoid() # Sigmoid for normalized coords
        pred_logits = self.cls_mlp(target_embeddings) # Confidence

        if produce_mask and visual_features_for_mask is not None:
            # 扩展到密集定位 (MaskFormer like)
            # target_embeddings (N_q, D) -> mask_embeddings (N_q, D_mask)
            # visual_features_for_mask (H_feat * W_feat, D_mask)
            # pred_masks (N_q, H_feat * W_feat) -> (N_q, H_img, W_img) via upsampling
            # mask_embed = self.mask_embed_mlp(target_embeddings)
            # pred_masks = torch.einsum("bqd,bpd->bqp", mask_embed, visual_features_for_mask)
            pred_masks = None # Placeholder for segmentation mask logic
            return pred_boxes, pred_logits, pred_masks
            
        return pred_boxes, pred_logits


# RefFormer 主模型
class RefFormer(nn.Module):
    def __init__(self, clip_model_name="ViT-B/32", num_queries=3, proj_dim_l=256, 
                 qa_insert_layers=(3,5,7,9,11), # 对应CLIP视觉层索引,0-indexed
                 decoder_layers=6, num_heads=8, clip_dim_v=768, clip_dim_t=512):
        super().__init__()
        self.clip_frozen = FrozenCLIP(visual_layers_to_return=qa_insert_layers, text_layers_to_return=(-1,))
        self.num_queries = num_queries
        self.proj_dim_l = proj_dim_l
        self.qa_insert_layers = qa_insert_layers
        
        # 初始查询 (随机初始化或来自语言嵌入)
        self.initial_referential_queries = nn.Parameter(torch.randn(1, num_queries, proj_dim_l))

        # QA模块列表
        self.qa_modules = nn.ModuleList([
            QueryAdaptationModule(clip_dim_v, clip_dim_t, proj_dim_l, proj_dim_l, num_heads)
            for _ in qa_insert_layers
        ])

        # 语言引导的多层级视觉特征融合模块
        self.multi_level_fusion_mhca = MHCA(proj_dim_l, clip_dim_t, proj_dim_l, num_heads) # Vis_k, Text_sos, Text_sos
        self.multi_level_fusion_proj = nn.Linear(proj_dim_l * len(qa_insert_layers), proj_dim_l) # 调整维度

        # 任务特定解码器
        self.decoder = RefFormerDecoder(decoder_layers, proj_dim_l, num_heads, num_queries)
        
        # 定位头
        self.grounding_head = GroundingHead(proj_dim_l, num_classes=1, num_queries=num_queries)

    def forward(self, image_input, text_tokens):
        B = image_input.size(0)
        
        # 1. CLIP 特征提取 (简化流程,实际是逐层通过CLIP并插入QA)
        # 在真实实现中,CLIP的forward会更复杂,以允许QA模块的注入和残差更新
        
        # 模拟从CLIP的文本编码器获取最终文本特征 H_t (D_clip_t) 和 Z_t_i (各层文本特征)
        # Z_t_last = self.clip_frozen.text_encoder(text_tokens) # (B, N_t, D_clip_t)
        # H_t_for_decoder = self.text_feat_proj_for_decoder(Z_t_last) # (B, N_t, D_l) if needed
        # 假设 Z_t_last 和 H_t_for_decoder 维度已处理好
        # 简单起见,QA模块中也用 Z_t_last
        Z_t_last = torch.randn(B, text_tokens.size(1), 512) # 假设这是CLIP文本输出
        H_t_for_decoder = F.adaptive_avg_pool1d(Z_t_last.transpose(1,2),1).squeeze(-1) # 简化:用全局文本特征给多层视觉融合
        H_t_for_decoder_expanded = H_t_for_decoder.unsqueeze(1).repeat(1, text_tokens.size(1),1) # (B, N_t, D) for decoder MHCA

        # 2. 通过CLIP视觉层和QA模块,迭代更新查询和特征
        current_Q = self.initial_referential_queries.expand(B, -1, -1)
        
        # Z_v_from_clip_layer = self.clip_frozen.visual_encoder.patch_embed(image_input)
        # Z_v_from_clip_layer = self.clip_frozen.visual_encoder.cls_token + Z_v_from_clip_layer + self.clip_frozen.visual_encoder.pos_embed
        Z_v_current = torch.randn(B, 197, 768) # 模拟初始视觉输入 (patch_embed + cls + pos_embed)

        all_qa_aux_outputs = []
        multi_level_visual_features_from_qa = [] # 存储来自QA的 Z_v_i_hat (已降维)

        # 模拟CLIP视觉编码器逐层计算,并在指定层后插入QA
        # 注意:CLIP模型自身是冻结的,QA模块是可训练的,并通过残差连接影响送入下一CLIP(冻结)层的特征
        clip_visual_layers = self.clip_frozen.visual_transformer_layers # 假设可以访问
        
        qa_module_idx = 0
        for i, clip_layer in enumerate(clip_visual_layers):
            # 通过冻结的CLIP层
            Z_v_current = clip_layer(Z_v_current) # (B, N_v, D_clip_v)
            
            if i in self.qa_insert_layers:
                qa_mod = self.qa_modules[qa_module_idx]
                
                # QA模块的输入 Z_t_i 通常是CLIP文本编码器对应层的输出,
                # 但论文图示和简化实现常使用最终文本特征 Z_t_last
                Z_v_res_update, Z_t_res_update, current_Q, aux_preds = qa_mod(Z_v_current, Z_t_last, current_Q)
                
                all_qa_aux_outputs.append(aux_preds)
                
                # 残差式更新送入下一CLIP层的特征 (Z_v_i_hat 是升维后的)
                Z_v_current = Z_v_current + Z_v_res_update 
                # Z_t_last = Z_t_last + Z_t_res_update # 如果文本特征也逐层更新

                # 收集QA模块输出的视觉特征 (降维后的 G_v_i) 用于后续多层融合
                # 这里需要从qa_mod内部获取G_v_i,或者让qa_mod也返回它
                # 假设 qa_mod 返回了 G_v_i (proj_dim_l维度)
                _, _, _, G_v_i_for_fusion = qa_mod(Z_v_current, Z_t_last, current_Q) # 再次调用或修改返回
                G_v_i_for_fusion = torch.randn(B, Z_v_current.size(1), self.proj_dim_l) # 占位
                multi_level_visual_features_from_qa.append(G_v_i_for_fusion)
                
                qa_module_idx += 1

        final_referential_queries_Q = current_Q # (B, N_q, D_l)

        # 3. 语言引导的多层级视觉特征融合 (Eq 14-17)
        # H_t_sos = Z_t_last[:, 0:1, :] # 取[SOS] token作为全局文本信息 (D_clip_t)
        # 假设 H_t_sos 已经投影到 D_l
        H_t_sos_proj = self.text_down_proj(Z_t_last[:, 0:1, :]) # (B, 1, D_l)

        fused_vis_levels = []
        for Z_v_k_hat in multi_level_visual_features_from_qa: # Z_v_k_hat 是 (B, N_v, D_l)
            H_v_k_fused = self.multi_level_fusion_mhca(Z_v_k_hat, H_t_sos_proj, H_t_sos_proj)
            fused_vis_levels.append(H_v_k_fused + Z_v_k_hat) # 残差
        
        # 拼接并投影
        concatenated_fused_vis = torch.cat(fused_vis_levels, dim=1) # (B, N_total_v, D_l)
        # 论文中没有这一步投影,而是直接用拼接的特征,或者每一级特征分别处理
        # 假设解码器能处理拼接后的特征,或者有进一步处理
        # H_vml = self.multi_level_fusion_proj(concatenated_fused_vis) # (B, N_some_v, D_l)
        H_vml = concatenated_fused_vis # 解码器内部可能再处理长度

        # 4. 任务特定解码器
        # H_t_for_decoder 应该是 (B, N_t, D_l)
        # 假设 Z_t_last 经过了合适的投影
        H_t_for_decoder_proj = self.text_down_proj(Z_t_last) if Z_t_last.size(-1) != self.proj_dim_l else Z_t_last

        target_embeddings = self.decoder(final_referential_queries_Q, H_vml, H_t_for_decoder_proj)
        
        # 5. 定位头
        # 如果需要分割,H_vml 可以作为 visual_features_for_mask (可能需要reshape和投影)
        pred_boxes, pred_logits = self.grounding_head(target_embeddings)
        
        outputs = {
            "pred_boxes": pred_boxes, # (B, N_q, 4)
            "pred_logits": pred_logits, # (B, N_q, num_classes)
            "aux_outputs": all_qa_aux_outputs # List of (box, cls) tuples from QA modules
        }
        return outputs

# --- 主程序流程概念 ---
# model = RefFormer()
# image = torch.randn(2, 3, 640, 640) # Batch of 2 images
# text_tokens = torch.randint(0, 1000, (2, 77)) # Batch of 2 tokenized texts

# outputs = model(image, text_tokens)

# pred_boxes = outputs["pred_boxes"]
# pred_logits = outputs["pred_logits"]

# # 计算损失 (包括主损失和辅助损失)
# # loss_main = compute_detection_loss(pred_boxes, pred_logits, ground_truth_boxes, ground_truth_labels)
# # loss_aux = 0
# # for aux_out in outputs["aux_outputs"]:
# #     loss_aux += compute_detection_loss(aux_out[0], aux_out[1], ground_truth_boxes, ground_truth_labels) # 辅助损失也用主GT
# # total_loss = loss_main + lambda_aux * loss_aux
# # total_loss.backward()

  1. 冻结CLIP骨干FrozenCLIP 类代表了这一点,其参数不参与训练。
  2. 查询自适应模块 (QA)QueryAdaptationModule 是核心。
    • 降维/升维visual_down_proj, text_down_proj, visual_up_proj, text_up_proj
    • CAMF 与 TR 块:内部通过MHCAMHSA以及FFN实现多模态信息融合和上下文优化。
    • 参考性查询迭代优化Q_prev 输入和 Q_next 输出,在RefFormerforward循环中体现。
    • 调节token和特征调制reg_token_v/ttorch.sigmoid(G_v_reg) 的使用。
    • 残差注入Z_v_current = Z_v_current + Z_v_res_update 体现了QA模块作为适配器,将任务知识注入回(概念上的)CLIP流。
    • 辅助损失输出aux_grounding_head 用于监督中间的参考性查询。
  3. 多层级特征利用
    • QA模块被插入到CLIP视觉编码器的多个指定层 (qa_insert_layers)。
    • multi_level_visual_features_from_qa 收集来自不同QA模块的视觉特征。
    • Language-guided Multi-level Fusion部分(multi_level_fusion_mhca)融合这些特征并注入语言信息。
  4. 参考性查询引导解码
    • RefFormerDecoder接收最后一个QA模块输出的final_referential_queries_Q
    • 解码器查询通过self.query_mlp(referential_queries_Q) + self.decoder_queries_init结合了先验信息。
  5. DETR类解码与输出
    • RefFormerDecoderGroundingHead 遵循了DETR的思路进行目标框和置信度的预测。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

frostmelody

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值