Text to image论文精读 NAAF:基于负感知注意力的图像-文本匹配框架 Negative-Aware Attention Framework for Image-Text Matching

NAAF:Negative-Aware Attention Framework for Image-Text Matching是基于负感知注意力的图文匹配,其利用匹配片段的积极影响和不匹配片段的消极影响来共同推断图像-文本的相似性,文章由中国科学技术大学和北京邮电大学学者在2022CVPR上发表。

论文地址:https://ieeexplore.ieee.org/document/9879764
代码地址:https://github.com/CrossmodalGroup/NAAF.
作者博客地址:https://www.cnblogs.com/lemonzhang/p/16456403.html

在这里插入图片描述
注意:这篇论文主要讨论的是图像和文本的匹配,即语义一致性的方法。不是专用于做文本生成图像的系列论文,本篇文章是阅读这篇论文的精读笔记。

一、原文摘要

图文匹配作为一项基本任务,弥合了视觉和语言之间的鸿沟。这项任务的关键是准确测量这两种模式之间的相似性。先前的工作主要基于匹配的片段(即,具有高相关性的单词/区域)来测量这种相似性,同时低估甚至忽略了不匹配的片段的影响(即,低相关性的单词或区域),例如,通过典型的LeaklyReLU或ReLU操作,该操作迫使负分数接近或精确到零。这项工作认为,不匹配的文本片段(包含丰富的不匹配线索)对图像文本匹配也至关重要。

因此,我们提出了一种新的消极意识注意框架(NAAF),该框架明确地利用匹配片段的积极影响和不匹配片段的消极影响来共同推断图像-文本的相似性。NAAF(1)精心设计了一种迭代优化方法,以最大限度地挖掘不匹配的片段,促进更具辨别力和鲁棒性的负面影响,(2)设计了双分支匹配机制,以精确计算具有不同掩码的匹配/不匹配片段的相似性/不相似性程度。在两个基准数据集(即Flickr30K和MSCOCO)上进行的大量实验证明了我们的NAAF的卓越性能,达到了最先进的性能。

二、图像-文本匹配

图像文本匹配任务定义:也称为跨模态图像文本检索,即通过某一种模态实例, 在另一模态中检索语义相关的实例。例如,给定一张图像,查询与之语义对应的文本,反之亦然。具体而言,对于任意输入的文本-图像对(Image-Text Pair),图文匹配的目的是衡量图像和文本之间的语义相似程度(这也是文本生成图像中很重要的一个点)。

图片和文字由于模态的异构,存在极大的语义鸿沟。图文匹配的关键挑战在于准确学习图片和文本之间的语义对应关系,并度量它们的相似性。在现有的图像文本匹配方法中有两种范式:

  1. 第一种方法倾向于执行全局级匹配,即找到文本和整个图像之间的语义对应。他们通常将整体图像和文本投射到一个共同的潜在空间,然后匹配这两种模式。(CLIP可以被分为此类,其将图像和文本同时投影到一个计算矩阵中,计算其相似度)。
  2. 第二种范式侧重于检查局部级匹配,即图像中的显著区域和文本中的单词之间的匹配。局部级别匹配考虑了图像和文本之间的细粒度语义对应AttnGAN的DAMSM就是基于这个原理,其将句子的图像和单词的子区域映射到一个公共语义空间,从而在单词级别测量图像-文本相似度,以计算图像生成的细粒度损失。

在局部级匹配的领域,基于注意力的匹配框架最近迅速成为主流,其关键思想是通过注意力关注来自另一模态的每个查询片段的相关片段来发现所有单词-图像区域对齐

三、为什么提出NAAF?

显然匹配的片段(即,具有高相关性分数的单词区域对)将对最终的图像-文本相似性做出很大贡献,而不匹配片段(即具有低相关性分数的词区域对)的影响将被削弱甚至消除,例如,通过在注意力过程中迫使负分数接近或精确为零的典型LeakyReLUReLU

现有的很多方法主要寻找匹配的片段,而低估或忽略了不匹配片段的影响,完全忽略了不匹配的文本片段在证明图像文本不匹配中的关键作用,将不可避免地容易产生假阳性匹配:

假阳性匹配
包含许多匹配片段但有一些不匹配的文本片段的图像-文本对(直接表明图像-文本不匹配)仍然可以获得高相似度,并且可以正确地排在最前面,这肯定不是一个令人满意的结果,比如说下例两个男孩在一些树旁的路上踢足球:现有的方法主要寻找匹配的片段,例如“男孩”、“树”,以计算图像-文本(I-T)相似性,而不匹配的片段(例如“足球”)的影响被典型的LeaklyReLUReLU削弱或忽略,显然这并不是一个很优秀的匹配,但是由于他在大部分关键词上匹配得分高,其匹配结果会非常靠前,这就属于假阳性。
在这里插入图片描述

其主要集中于最大化匹配(即对齐)片段的效果,而低估或忽略了不匹配片段的线索作用。而合理的匹配框架应该同时考虑两个方面,即图像文本对的总体匹配分数不仅由匹配片段的积极影响决定,而且由不匹配片段的消极影响决定。可以充分的挖掘非对齐片段的负面作用,使原本检索在Top位置的错误匹配降低相似分值,对图像匹配度进行减分,如下图所示,就可以很容易消除假阳性。
在这里插入图片描述

因此,作者提出了一种新的负感知注意力框架,该框架首次明确考虑了正匹配和负不匹配的片段,以联合测量图像-文本的相似性。与片面关注匹配片段的传统匹配机制不同,该注意力框架可以有效地挖掘不匹配的文本片段,以进一步利用这两种类型的线索进行联合相似性推断。并使用它们准确地反映两种模式之间的差异。消极感知注意框架NAAF由两个模块组成:

  1. 设计了一种双分支匹配来解决不匹配片段的利用率不足的问题,它包含了不同掩码下的消极和积极注意,一种用于精确计算不匹配片段之间的相似度,另一种用于计算匹配片段之间相似度。分别测量精确的相似度/相异度,以联合推断整体图像-文本相似度。
  2. 提出了一种新的迭代优化方法来显式地建模和挖掘不匹配的片段。

四、NAAF

NAAF的总体框架如图所示,可以看到,首先Feature Extraction提取图像特征和文本特征(这里不再展开),然后Negative-aware Attntion使用负效应和正效应进行负意识注意以测量图像和文本的相似性,其包括两个主要模块,用于显式地利用负不匹配和正匹配的文本片段来联合推断图像-文本相似性。1.不匹配挖掘模块使失配线索产生更稳健的负面影响。2.正负双分支匹配模块精确计算两种类型片段的正面和负面影响,从而测量总体相似性。
在这里插入图片描述

4.1、特征提取

  • 视觉表征:给定图像V,利用Visual Genome 预训练的FasterRCNN检测显著对象和其他区域。然后,通过预训练的ResNet-101过平均池卷积特征提取检测区域。采用全连接层将每个区域映射到1024维特征。
  • 文本表征:给定由m个单词组成的文本U,我们将每个单词热编码为1024维向量,并嵌入预先训练的GloVe向量中,然后,向量被馈送到双向门控循环单元(BiGRU)中,以整合前向和后向上下文信息。最终的单词表示 u i u_i ui是双向隐藏状态的平均值。

4.2、Negative-aware Attntion

给定一个图像-文本对,它可能包含丰富的匹配和不匹配片段。本模块的目标就是充分利用这两类线索,以实现更准确的匹配性能。在NAAF框架中主要有两个模块:

  1. 不匹配挖掘模块:旨在通过最小化训练过程中匹配和不匹配相似性分布之间错误重叠的惩罚概率,明确建模和最大限度地挖掘不匹配片段。
  2. 正负双分支匹配模块:旨在通过设计的两个分支匹配,即负和正注意分支,精确计算负失配和正匹配的影响,以共同推断相似性。

1️⃣:不匹配挖掘模块

不匹配挖掘模块期望显式地和自适应地建模失配和匹配片段的相似性分布,旨在最大限度地分离它们,以实现有效的不匹配片段挖掘。
在这里插入图片描述

为此,在训练过程中,对于不匹配和匹配的单词区域片段对,首先对它们的相似度进行采样:

S k − = [ s 1 − , s 2 − , s 3 − , … , s i − , … ] S k + = [ s 1 + , s 2 + , s 3 + , … , s i + , … ] \begin{aligned} S_{k}^{-} &=\left[s_{1}^{-}, s_{2}^{-}, s_{3}^{-}, \ldots, s_{i}^{-}, \ldots\right] \\ S_{k}^{+} &=\left[s_{1}^{+}, s_{2}^{+}, s_{3}^{+}, \ldots, s_{i}^{+}, \ldots\right] \end{aligned} SkSk+=[s1,s2,s3,,si,]=[s1+,s2+,s3+,,si+,]
其中S-表示不匹配区域-单词的相似度分数,S+表示匹配区域-单词的相似度分数。

基于构造出的两个集合,可以分别建立匹配片段和不匹配片段的相似度分数s的概率分布模型:

分布模型公式表示为: f k − ( s ) = 1 σ k − 2 π e [ − ( s − μ k − ) 2 2 ( σ k − ) 2 ] , f k + ( s ) = 1 σ k + 2 π e [ − ( s − μ k + ) 2 2 ( σ k + ) 2 ] f_{k}^{-}(s)=\frac{1}{\sigma_{k}^{-} \sqrt{2 \pi}} e^{\left[-\frac{\left(s-\mu_{k}^{-}\right)^{2}}{2\left(\sigma_{k}^{-}\right)^{2}}\right]}, f_{k}^{+}(s)=\frac{1}{\sigma_{k}^{+} \sqrt{2 \pi}} e^{\left[-\frac{\left(s-\mu_{k}^{+}\right)^{2}}{2\left(\sigma_{k}^{+}\right)^{2}}\right]} fk(s)=σk2π 1e[2(σk)2(sμk)2],fk+(s)=σk+2π 1e[2(σk+)2(sμk+)2]
其中(µ−k,σ−k)(µ+k,σ+k)分别是两种分布的平均值和标准差:
在这里插入图片描述

分别得到两个相似度分布建模后,可以用一个显式的边界t在匹配片段和不匹配片段之间进行区分,如图所示,相似度分数大于 t k t_k tk的区域-单词对被视为匹配片段,反之则为不匹配片段,但是不可避免的就会出现两种误判:将实际上不匹配的片段区分为匹配的 和 将实际上匹配的片段误认为是不匹配的。而此模块的目的是最大限度的挖掘出不匹配片段,找出一个最优的边界t,使得区分错误的概率最低,保证识别的准确性,即解决如下优化问题:

min ⁡ t α ∫ t + ∞ f k − ( s ) d s + ∫ − ∞ t f k + ( s ) d s ,  s.t.  t ≥ 0 \begin{array}{ll} \min _{t} & \alpha \int_{t}^{+\infty} f_{k}^{-}(s) d s+\int_{-\infty}^{t} f_{k}^{+}(s) d s, \\ \text { s.t. } & t \geq 0 \end{array} mint s.t. αt+fk(s)ds+tfk+(s)ds,t0
其中t是该问题的决策变量,α是惩罚参数。

对于该问题的最优解求解,我们首先搜索它的一阶导数的零点,并根据可行域的约束条件在(t ≥ 0)处截断,得到最优解为:
t k = [ ( ( β 2 k 2 − 4 β 1 k β 3 k ) 1 2 − β 2 k ) / ( 2 β 1 k ) ] + 其中 β 1 k = ( σ k + ) 2 − ( σ k − ) 2 , β 2 k = 2 ( μ k + σ k − 2 − μ k − σ k + 2 ) ,   β 3 k = ( σ k + μ k − ) 2 − ( σ k − μ k + ) 2 + 2 ( σ k + σ k − ) 2 ln ⁡ σ k − α σ k + . \begin{array}{c} t_{k}=\left[\left(\left(\beta_{2}^{k^{2}}-4 \beta_{1}^{k} \beta_{3}^{k}\right)^{\frac{1}{2}}-\beta_{2}^{k}\right) /\left(2 \beta_{1}^{k}\right)\right]_{+} \\ \text {其中} \beta_{1}^{k}=\left(\sigma_{k}^{+}\right)^{2}-\left(\sigma_{k}^{-}\right)^{2}, \beta_{2}^{k}=2\left(\mu_{k}^{+} \sigma_{k}^{-2}-\mu_{k}^{-} \sigma_{k}^{+2}\right), \text { } \beta_{3}^{k}=\left(\sigma_{k}^{+} \mu_{k}^{-}\right)^{2}-\left(\sigma_{k}^{-} \mu_{k}^{+}\right)^{2}+2\left(\sigma_{k}^{+} \sigma_{k}^{-}\right)^{2} \ln \frac{\sigma_{k}^{-}}{\alpha \sigma_{k}^{+}} . \end{array} tk=[((β2k24β1kβ3k)21β2k)/(2β1k)]+其中β1k=(σk+)2(σk)2,β2k=2(μk+σk2μkσk+2), β3k=(σk+μk)2(σkμk+)2+2(σk+σk)2lnασk+σk.

2️⃣:正负双分支匹配模块

在这里插入图片描述
双分支框架可以同时关注图像-文本对中不匹配和匹配的片段,方法是使用不同的注意力掩码分别精确测量它们在负注意力和正注意力中的影响。

具体地说,首先计算所有单词和区域之间的语义相关性得分为:

s i j = u i v j T ∥ u i ∥ ∥ v j ∥ , i ∈ [ 1 , m ] , j ∈ [ 1 , n ] s_{i j}=\frac{u_{i} v_{j}^{\mathrm{T}}}{\left\|u_{i}\right\|\left\|v_{j}\right\|}, i \in[1, m], j \in[1, n] sij=uivjuivjT,i[1,m],j[1,n],这里原理与AttnGAN中的DAMSM类似。

然后使用不同的注意力掩码双线计算:

  1. 负注意力分支:这一模块的目标是准确有效地利用不匹配的片段,使它们有价值地降低不匹配图像-文本对的整体相似性。分支依然从文本的角度出发,计算一个文本单词和一个图像所有区域的相似度与区分边界 t k t_k tk的差,其中的最大值体现了这个片段是匹配还是不匹配的程度: s i = max ⁡ j ( { s i j − t k } j = 1 n ) s_{i}=\max _{j}\left(\left\{s_{i j}-t_{k}\right\}_{j=1}^{n}\right) si=maxj({sijtk}j=1n),由此,可以衡量出一个图像文本对中第个单词所带来的负面作用为: s i neg  = s i ⊙ Mask ⁡ neg  ( s i ) s_{i}^{\text {neg }}=s_{i} \odot \operatorname{Mask}_{\text {neg }}\left(s_{i}\right) sineg =siMaskneg (si) ,其中 M a s k n e g ( ⋅ ) Mask_{neg}(⋅) Maskneg()为掩码函数,当输入为负数时输出为1,否则为0。同时,考虑到单词在文本内的语义内关系,使语义相似的单词获得相同的匹配关系,在推理过程中,对每个单词的匹配程度进行一次模态内传播: s ^ i = ∑ l = 1 m w i l i n t r a s l , s.t.  w i l i n t r a = softmax ⁡ λ ( { u i u l T ∥ u i ∥ ∥ u l ∥ } l = 1 m ) \hat{s}_{i}=\sum_{l=1}^{m} w_{i l}^{i n t r a} s_{l} \text {, s.t. } w_{i l}^{i n t r a}=\operatorname{softmax}_{\lambda}\left(\left\{\frac{u_{i} u_{l}^{\mathrm{T}}}{\left\|u_{i}\right\|\left\|u_{l}\right\|}\right\}_{l=1}^{m}\right) s^i=l=1mwilintrasl, s.t. wilintra=softmaxλ({uiuluiulT}l=1m),其中 w i l i n t r a w^{intra}_{il} wilintra表示第i个和第l个单词之间的语义关系,λ是比例因子。
  2. 正注意力分支:该分支旨在测量图像-文本对的相似程度,首先关注跨模态的共享语义,第i个单词在图像中相关的共享语义可以被聚合为: w i j i n t e r = softmax ⁡ λ ( { Mask ⁡ pos  ( s i j − t k ) } j = 1 n ) w_{i j}^{i n t e r}=\operatorname{softmax}_{\lambda}\left(\left\{\operatorname{Mask}_{\text {pos }}\left(s_{i j}-t_{k}\right)\right\}_{j=1}^{n}\right) wijinter=softmaxλ({Maskpos (sijtk)}j=1n),其中 M a s k p o s ( ⋅ ) Mask_{pos}(⋅) Maskpos()为掩码函数,当输入为正数时输出与输入相等,否则输出 − ∞ −∞ ,这样使得不相关的图像区域的注意力权重被削减至0,由此,片段的相似度分数为 s i f = u i v ^ i T / ( ∥ u i ∥ ∥ v ^ i ∥ ) s_{i}^{f}=u_{i} \hat{v}_{i}^{\mathrm{T}} /\left(\left\|u_{i}\right\|\left\|\hat{v}_{i}\right\|\right) sif=uiv^iT/(uiv^i),另外,区域与单词间的相关度分数也反应了图文间的相似程度,故作者还根据单词的相应相关性得分计算加权相似度: s i r = ∑ j = 1 n w i j r e l e v s i j s_{i}^{r}=\sum_{j=1}^{n} w_{i j}^{r e l e v} s_{i j} sir=j=1nwijrelevsij 其中, w i j r e l e v = softmax ⁡ λ ( { s ˉ i j } j = 1 n ) , s ˉ i j = [ s i j ] + / ∑ i = 1 m [ s i j ] + 2 w_{i j}^{r e l e v}=\operatorname{softmax}_{\lambda}\left(\left\{\bar{s}_{i j}\right\}_{j=1}^{n}\right), \bar{s}_{i j}=\left[s_{i j}\right]+/ \sqrt{\sum_{i=1}^{m}\left[s_{i j}\right]_{+}^{2}} wijrelev=softmaxλ({sˉij}j=1n),sˉij=[sij]+/i=1m[sij]+2 ,因此,一个图像文本对中第个单词所带来的正面作用为: s i pos  = s i f + s i r s_{i}^{\text {pos }}=s_{i}^{f}+s_{i}^{r} sipos =sif+sir

最终,图像文本对 (U,V)的相似度由正面作用和负面作用共同决定 S ( U , V ) = 1 m ∑ i = 1 m ( s i neg  + s i pos  ) S(U, V)=\frac{1}{m} \sum_{i=1}^{m}\left(s_{i}^{\text {neg }}+s_{i}^{\text {pos }}\right) S(U,V)=m1i=1m(sineg +sipos )

4.3、采样和更新策略

  1. 对于对齐的单词,在正确的图像中至少有一个匹配区域。因此对单词 u i u_i ui,图像区域{ v j + v^+_j vj+} j = 1 n ^n_{j=1} j1n之间的最大相似性进行采样: s i + = max ⁡ j ( { v j + u i T / ( ∥ v j + ∥ ∥ u i ∥ ) } j = 1 n ) s_{i}^{+}=\max _{j}\left(\left\{v_{j}^{+} u_{i}^{\mathrm{T}} /\left(\left\|v_{j}^{+}\right\|\left\|u_{i}\right\|\right)\right\}_{j=1}^{n}\right) si+=maxj({vj+uiT/( vj+ ui)}j=1n)
  2. 对于未对齐的单词,不正确图像中的所有区域都与其不匹配。因此对单词 u i u_i ui,图像区域{ v j + v^+_j vj+} j = 1 n ^n_{j=1} j1n,对其采样为: s i − = max ⁡ j ( { v j − u i T / ( ∥ v j − ∥ ∥ u i ∥ ) } j = 1 n ) , s_{i}^{-}=\max _{j}\left(\left\{v_{j}^{-} u_{i}^{\mathrm{T}} /\left(\left\|v_{j}^{-}\right\|\left\|u_{i}\right\|\right)\right\}_{j=1}^{n}\right), si=maxj({vjuiT/( vj ui)}j=1n),

此外,为了对精确的伪词区域相似性标签进行采样,作者基于计算的相似度排名的正确性设计来决定是否更新 s i + s^+_i si+ s i − s^−_i si

4.4、损失函数

本文中用于端到端训练的目标函数是双向三元组排序损失,损失函数如下:

L = ∑ ( U , V ) [ γ − S ( U , V ) + S ( U , V ′ ) ] + + [ γ − S ( U , V ) + S ( U ′ , V ) ] + L=\sum_{(U, V)}\left[\gamma-S(U, V)+S\left(U, V^{\prime}\right)\right]_{+}+\left[\gamma-S(U, V)+S\left(U^{\prime}, V\right)\right]_{+} L=(U,V)[γS(U,V)+S(U,V)]++[γS(U,V)+S(U,V)]+

其中: (U, V )表示成功匹配的图像和匹配的文本,(U, V′)和(U′, V )表示未成功匹配的图像和文本。

五、实验

5.1、实验设置

1️⃣数据集:Flickr30K总共有31000张图片和155000个句子,其被分成1000张测试图像、1000张验证图像和29000张训练图像。MS-COCO包含123287张图像和616435个句子,将其分为5000张测试图像、5000张验证图像和113287张训练图像。

2️⃣评估指标:Recall(R@K,K=1,5,10)rSum。R@K表示检索到的前K个列表中的地面真相的百分比。rSum是所有R@K在图像到文本和文本到图像中,反映了整体匹配性能。

3️⃣实现细节:显卡为RTX 3090Ti GPU,优化器为Adam,初始学习率为0.0005,每10个周期衰减10%。Flickr30KMSCOCO的最小批量大小分别设置为128和256,两个数据集上都有20个epoches,特征尺寸d被设置为1024。λ设置为20,α设置为2.0,γ设置0.2。

5.2、实验结果

定量指标:
在这里插入图片描述
可视化最优阈值学习过程:
在这里插入图片描述
不匹配线索挖掘对比(蓝色为不匹配):
在这里插入图片描述

六、总结

这项工作的主要贡献总结如下。

1) 提出了一种新颖的双分支匹配模块,该模块联合利用不匹配和匹配的文本片段进行精确的图像文本匹配。与传统的关注不同,该方法可以同时关注失配和匹配片段,以明确地利用它们的负面和正面影响。双分支匹配机制能够分别测量精确的相似度/相异度,以联合推断整体图像-文本相似度

2) 我们提出了一种新的具有负挖掘策略的迭代优化方法,该方法可以以最大限度地挖掘负面失配片段,明确地驱动不匹配片段的更多负面影响,并从理论上保证挖掘的准确性,产生更全面和可解释的图像-文本相似性度量。

3) 在Flickr30K和MS-COCO两个基准上进行的大量实验表明,NAAF的表现优于比较方法。分析也充分证明了我们方法的优越性和合理性。

💡 最后

我们已经建立了🏤T2I研学社群,如果你对本文还有其他疑问或者对🎓文本生成图像/文本生成3D方向很感兴趣,可以点击下方链接或者私信我加入社群

📝 加入社群 抱团学习中杯可乐多加冰-采苓AI研习社

🔥 限时免费订阅文本生成图像T2I专栏

🎉 支持我:点赞👍+收藏⭐️+留言📝

  • 17
    点赞
  • 47
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 9
    评论
下面是一个示例的PyTorch代码,演示了如何使用CLIP文本编码器和ResNet-50模型来构建注意力模块的输入,并进行训练: ```python import torch import torch.nn as nn import torch.optim as optim import torchvision.models as models from transformers import CLIPModel, CLIPProcessor # 构建CLIP模型和处理器 clip_model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32') clip_processor = CLIPProcessor.from_pretrained('openai/clip-vit-base-patch32') # 加载ResNet-50模型 resnet_model = models.resnet50(pretrained=True) resnet_model.fc = nn.Identity() # 去掉最后的全连接层 # 定义注意力模块 class AttentionModule(nn.Module): def __init__(self): super(AttentionModule, self).__init__() self.fc1 = nn.Linear(512, 256) self.fc2 = nn.Linear(512, 256) def forward(self, text_features, image_features): text_attention = self.fc1(text_features) image_attention = self.fc2(image_features) attention_scores = torch.matmul(text_attention, image_attention.T) attention_weights = torch.softmax(attention_scores, dim=1) attended_text_features = torch.matmul(attention_weights, text_features) attended_image_features = torch.matmul(attention_weights.T, image_features) return attended_text_features, attended_image_features # 创建注意力模块实例 attention_module = AttentionModule() # 定义损失函数和优化器 loss_fn = nn.MSELoss() optimizer = optim.Adam(attention_module.parameters(), lr=0.001) # 准备示例输入数据 text_input = "example text" image_input = torch.randn(1, 3, 224, 224) # 示例图像输入 # 进行输入数据的预处理 text_inputs = clip_processor(text_input, return_tensors="pt", padding=True) image_inputs = clip_processor(images=image_input, return_tensors="pt", padding=True) # 获取CLIP文本编码器的特征 with torch.no_grad(): text_features = clip_model.get_text_features(**text_inputs).to(device) # 获取ResNet-50模型的特征 with torch.no_grad(): image_features = resnet_model(image_inputs['pixel_values'].to(device)) # 将特征输入到注意力模块,并计算输出 attended_text_features, attended_image_features = attention_module(text_features, image_features) # 计算损失并进行反向传播 loss = loss_fn(attended_text_features, attended_image_features) optimizer.zero_grad() loss.backward() optimizer.step() ``` 请注意,这只是一个示例代码,具体的实现方式和参数设置可能需要根据你的具体需求进行调整。此外,你还需要根据实际情况调整模型和训练过程中的超参数,以达到最佳性能。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

中杯可乐多加冰

请我喝杯可乐吧,我会多加冰!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值