多模态视觉语言工业质检 InCTRL 模型论文解读


《Toward Generalist Anomaly Detection via In-context Residual Learning with Few-shot Sample Prompts》

准备工作

论文地址:https://arxiv.org/pdf/2403.06495.pdf
项目地址:mala-lab/InCTRL:
语雀文档:https://www.yuque.com/yingmuhuadao-4o05h/edit5y/eg1cg43b1adwc3f2

一、论文通读

1.1 摘要

本文探讨了通用异常检测(GAD)问题,旨在训练一个单一的检测模型,该模型无需对目标数据进行任何进一步的训练,即可在不同应用领域的各种数据集中泛化检测异常情况。最近的一些研究表明,像 CLIP 这样的大型预训练视觉语言模型(VLM)在检测各种数据集中的工业缺陷方面具有很强的泛化能力,但它们的方法在很大程度上依赖于手工制作的缺陷文本提示,因此很难泛化到其他应用中的异常情况,例如医疗图像异常或自然图像中的语义异常。**在这项工作中,我们建议使用少量正常图像作为样本提示,在不同的数据集上即时训练 GAD 模型。**为此,我们引入了一种新颖的方法来学习 GAD 的非语境残差学习模型,称为 InCTRL。该模型在一个辅助数据集上进行训练,根据对查询图像和少量正常样本提示之间残差的整体评估来区分异常和正常样本。无论采用哪种数据集,根据异常的定义,异常样本的残差都会大于正常样本,从而使 InCTRL 无需进一步训练即可在不同领域通用。我们在九个 AD 数据集上进行了综合实验,以建立一个 GAD 基准,该基准涵盖了工业缺陷异常、医疗异常和语义异常的单类和多类检测。代码见 https:// github.com/mala-lab/InCTRL。

1.2 介绍

异常检测(AD)是一项重要的计算机视觉任务,其目的是检测与数据集中大多数样本存在重大偏差的样本,在工业检测、医学成像分析和科学发现等现实生活中有着广泛的应用。[12, 37]. 目前的 AD 范式主要是在每个目标数据集的训练数据(如一组无异常样本)上单独建立一个模型。
虽然这些方法在各种反向干扰基准上都显示出了不俗的检测性能,但它们需要大量的训练数据,并需要对每个数据集的检测模型进行熟练的训练。因此,在一些应用场景中,由于数据隐私问题,在训练模型时使用这些数据所产生的数据隐私问题),或者在部署新应用时无法获得大规模训练数据,因此不允许在目标数据集上进行训练,这些方法就变得不可行了。
为了应对这些挑战,本文探讨了通用异常检测(GAD)模型的学习问题,旨在训练一个单一的检测模型,该模型可以在不对目标数据进行任何训练的情况下,在不同应用领域的不同数据集中泛化检测异常情况。
近年来,通过在网络规模的图像-文本数据上进行预训练,大型视觉语言模型(VLM)(如 CLIP)表现出了卓越的泛化能力,无需在目标数据上进行任何微调或调整,即可在不同数据集上实现准确的视觉识别。更重要的是,最近的一些研究(如 WinCLIP)表明,这些 VLM 也能在不同的缺陷检测数据集上实现显著的泛化。然而,这些模型的一个显著局限是它们依赖于大量人工制作的特定于缺陷的提示。这种依赖性限制了它们的适用性,使得将它们扩展到其他数据领域的异常检测具有挑战性,例如医疗图像异常或单类或多类设置中的语义异常。
为了解决这个问题,我们建议训练一个 GAD 模型,旨在利用任何目标数据集中的少量正常图像作为样本提示,以支持即时 GAD,如**图 1(上)所示。在实际应用中,通常很容易获得少量的正常图像,因此采用了少量图像的设置。此外,这些少量拍摄的样本并不用于模型训练/调整;它们只是用作样本提示,以便在推理过程中对测试图像进行异常评分。如图 1(下)**所示,这种方法与当前的少量样本 AD 方法有着本质区别,后者使用这些目标样本及其广泛的增强版本来训练检测模型,这可能会导致目标数据集的过度拟合,并且无法泛化到其他数据集。随后,我们介绍了一种 GAD 方法,这是第一种基于 CLIP 学习上下文残差学习模型的方法,称为 InCTRL。它通过学习识别查询图像与一组来自辅助数据的少量正常图像之间的残差/差异来训练 GAD 模型,从而从正常样本中分辨出异常图像。少量正常图像,即上下文中的样本提示,是正常模式的原型。与这些正常模式的特征相比,根据异常的定义,在不同领域的数据集中,异常的残差通常会大于正常样本的残差,因此学习到的上下文残差模型可以泛化到检测不同领域的各种异常类型。
图1 上图:InCTRL 的图示:使用少张正常图像作为样本提示的一换一模型。下图 InCTRL 和其他几幅图像方法在三个不同应用数据集上的 AUROC 曲线,无需对目标数据进行任何训练。
为了更好地捕捉残差,InCTRL 在图像和补丁两个层面对上下文残差进行建模,从而深入理解什么是异常。此外,我们的上下文残差学习还能将正常/异常文本提示引导的先验知识无缝纳入检测模型,为从文本图像对齐的语义空间进行检测提供额外的优势。
因此,我们的主要贡献如下。
- 我们引入了 GAD 任务,以评估 AD 方法在各种场景下识别异常的泛化能力,而无需在目标数据集上进行训练/调整。据我们所知,这是第一项专门针对异常检测通用方法的研究,涵盖了工业缺陷、医疗异常和语义异常。
- 然后,我们提出了一种用于 GAD 的上下文残差学习框架,称为 InCTRL。InCTRL 在辅助数据上进行了优化,以实现 "一模多用 "的目标,即在不同数据集上使用一个 AD 模型,而无需在目标数据上进行任何训练。
- 我们在九个不同的 AD 数据集上进行了综合实验,建立了一个 GAD 基准,该基准囊括了三类流行的 AD 任务,包括工业缺陷异常检测、医学图像异常检测以及单类和多类设置下的语义异常检测。我们的研究结果表明,InCTRL 明显优于最先进的竞争方法。

1.3 相关工作

1.3.1 异常检测

异常检测。由于异常数据稀缺,现有的异常检测方法通常依赖于无监督学习。目前已推出了许多方法。单类分类方法:专注于用支持向量紧凑描述正常数据。基于重建的方法:训练模型来重建正常图像,通过较高的重建误差来识别异常。基于距离的方法:根据存储的训练数据中测试图像嵌入和正常参考嵌入之间的距离来确定异常。知识提炼方法:侧重于从预训练模型中提炼正常模式,并根据提炼特征与原始特征之间的差异检测异常。
上述方法都是为适应 AD 目标数据集而设计的,即一个数据集一个模型。我们的目标是一个模型适用于所有设置。一个相关的研究方向是解决领域或分布转移下的 AD 问题,但它们一般都假定源数据和目标数据具有很大的领域相关性。此外,还有一些利用 VLMs 进行 AD 的并行研究,但它们采用了与我们不同的设置,例如弱监督 AD 或零样本 AD。
少量异常检测(FSAD)。FSAD 旨在仅使用目标数据集中数量有限的正常样本来识别异常。传统的 FSAD 研究侧重于对这些少数正常样本的正态分布建模,以检测异常。然而,这些方法通常无法推广到新的领域,因为它们通常需要使用目标数据集的正常数据进行重新训练或微调。基于距离的方法,如 SPADE 、PaDiM 和 PatchCore,提出了解决这一问题的方案,即充分利用现有的预训练的少量样本表示,无需训练即可计算基于距离的异常分数。WinCLIP率先将大型视觉语言模型(VLM)应用于零样本 和少样本异常检测任务,通过多尺度窗口移动和文本提示 CLIP 来处理图像。在不对 CLIP 进行调整以适应 AD 任务的情况下,WinCLIP 利用其手工制作的文本提示在缺陷数据集上获得了令人印象深刻的零次检测性能,但当文本提示无法捕捉所需的异常语义时,WinCLIP 就无法很好地发挥作用,从而使其难以很好地推广到各种异常检测任务中。

1.3.2 上下文学习

上下文学习是一种有助于提高自然语言处理(NLP)中大型语言模型(LLM)性能的创新方法,它利用最小的上下文提示来使 LLM 有效地适应新任务。最近,一些研究尝试将上下文学习应用于视觉任务,利用语言或专门设计的离散标记作为任务提示,将视觉问题转换为 NLP 问题。另一方面,Amir 等 将一系列视觉任务视为网格内绘画问题,从而引入了一种新颖的上下文视觉提示方法。然而,这些方法更注重任务级的泛化,因此不适用于更注重实例级差异的 AD 任务。我们的工作是重新设计 GAD 的上下文学习。我们将图像提示重新定义为特定数据集的正常模式,而不是特定任务的指令。通过捕捉查询图像和少量正常提示之间的上下文残差,我们的模型可以获得对各种异常情况的连贯理解,从而为 GAD 带来显著的通用检测性能。

1.4 InCTRL:上下文中的残差学习

1.4.1 InCTRL 方法概述

我们的 InCTRL 方法旨在有效地模拟查询图像与作为样本提示的一组少量正常图像之间的上下文残差,利用 CLIP 的泛化能力来检测来自不同应用领域的异常残差。CLIP 是一种 VLM,由文本编码器 f t ( ⋅ ) f_t(·) ft() 和视觉编码器 f v ( ⋅ ) f_v(·) fv() 组成,通过在网络规模的文本图像数据上进行预训练,使这些编码器的图像和文本表示方法完全一致。InCTRL 通过图像编码器中的上下文残差学习使用辅助数据 D t r a i n \mathcal{D}_{train} Dtrain 进行优化,并通过文本编码器中文本提示引导的先验知识来增强学习效果。
图2 InCTRL 训练概述。首先,它使用查询图像和从辅助训练数据中随机抽取的少量正常样本提示来模拟上下文学习场景。然后,它执行多层补丁级和图像级残差学习,以捕捉查询图像和正常提示之间的局部和全局残差。最后,这些残差信息与文本编码器提供的文本提示先验知识相结合,用于整体异常评分学习。
具体来说,如图 2 所示,我们首先模拟一个上下文学习示例,该示例包含一张查询图像 x x x 和一组很少的正常样本提示 P ′ \mathcal{P}^{\prime} P两者都是从辅助数据训练中随机抽样的然后,我们通过视觉编码器执行多层补丁级和图像级残差学习,分别捕捉查询图像和少量正常样本提示之间的局部和全局差异(第 1.4.2 和 1.4.3 节)。此外,我们的模型还可以根据文本提示嵌入和查询图像之间的相似性,将文本编码器的先验知识无缝纳入正常和异常文本提示(第 1.4.4 节)。InCTRL 的训练是优化视觉编码器上的几个投影/适配层,以便在冻结两个编码器原始参数的情况下,在 D t r a i n \mathcal{D}_{train} Dtrain中为异常样本学习到比正常样本更大的异常得分;在推理过程中,测试图像连同目标数据集中的少量正常图像提示和文本提示,将通过我们经过调整的基于 CLIP 的 GAD 网络进行推理,其输出是测试图像的异常得分(第 1.4.5 节)。下面我们将详细介绍这些模块。

1.4.2 多层补丁级残差学习

为了有效捕捉查询图像与正常图像提示之间的细粒度上下文残差,我们在 InCTRL 中引入了多层补丁级残差学习组件。通常,CLIP 视觉编码器由一系列块层组成。从底层到顶层,视觉编码器会逐渐学习不同抽象层次的视觉模式。因此,该组件旨在从视觉编码器内的多层块获得的补丁级标记嵌入中建立补丁级上下文残差模型。
具体来说,假设视觉编码器由 n 个块组成,对于一组给定的少样本正常提示 P ′ \mathcal{P}^{\prime} P和一幅训练查询图像 x x x。我们提取一系列补丁标记嵌入图 { T x l } l = 1 n \{T_x^l\}^n_{l=1} {Txl}l=1n { T x ′ l } l = 1 n \{T_{x'}^l\}^n_{l=1} {Txl}l=1n 其中, T ( ⋅ ) l ∈ R h × w × d T_{(\cdot)}^{l} \in \mathbb{R}^{h\times w\times d} T()lRh×w×d x ′ ∈ P ′ x^{\prime}\in\mathcal{P}^{\prime} xP,h、w 和 d 分别为特征图 T 的高度、宽度和维度。在每层 l l l**,补丁级上下文残差由 P ′ \mathcal{P}^{\prime} P 中所有图像提示中查询标记和图像提示标记的嵌入之间的距离来捕捉。**形式上,对于查询图像 x x x,其在第 l l l层的多层补丁级上下文残差由残差图 M x l ∈ R h × w \mathrm{M}_{x}^{l}\in\mathbb{R}^{h\times w} MxlRh×w来建模,其中 x x x的每个补丁的残差值根据其补丁嵌入和 P ′ \mathcal{P}^{\prime} P 中所有图像的最近补丁嵌入计算得出,即:
M x l ( i , j ) = 1 − ⟨ T x l ( i , j ) , h ( T x l ( i , j ) ∣ P ′ ) ⟩ \mathbf{M}_{x}^{l}(i,j)=1-\langle T_{x}^{l}(i,j),h(T_{x}^{l}(i,j)|\mathcal{P}^{\prime})\rangle Mxl(i,j)=1Txl(i,j),h(Txl(i,j)P)⟩
其中 h ( T x l ( i , j ) ∣ P ′ ) h(T_{x}^{l}(i,j)|\mathcal{P}^{\prime}) h(Txl(i,j)P)返回 P ′ \mathcal{P}^{\prime} P中所有图像补丁中与 T x l ( i , j ) T_{x}^{l}(i,j) Txl(i,j)最相似的补丁标记的嵌入, ⟨ ⋅ ⟩ ⟨·⟩ 是余弦相似度函数。最终的补丁级残差图 M x ∈ R h × w \mathrm{M}_{x}\in\mathbb{R}^{h\times w} MxRh×w 是 n 层残差图的平均值:
M x = 1 n ∑ l = 1 n M x l \mathbf{M}_x=\frac{1}{n}\sum_{l=1}^n\mathbf{M}_x^l Mx=n1l=1nMxl
M x \mathbf{M}_x Mx中的每个残差值都类似于 P ′ \mathcal{P}^{\prime} P中查询补丁与图像补丁集的最近邻距异常得分。正如之前的研究所示,这种基于距离的异常分数能有效区分异常和正常样本。因此,得到的残差图 M x \mathbf{M}_x Mx为 InCTRL 的后续异常评分学习提供了多层分辨率的集体异常判别能力特征集。

1.4.3 图像级残差学习

除了局部补丁级残差的判别能力,图像级的全局判别信息也很重要,是补丁级特征的补充知识。
因此,我们引入了图像级残差学习组件,以捕捉 x x x P ′ \mathcal{P}^{\prime} P之间更高层次的差异。直观地说,视觉编码器最后一个区块的类标记嵌入被用作特征输入,因为视觉编码器对信息进行了自下而上的抽象,所以它能捕捉到最多的图像级判别信息。不过,值得注意的是,CLIP 最初是为分类任务而设计的,侧重于景物中物体的语义,这与异常检测任务并不一致,因为正常和异常样本往往来自同一类物体。为了解决这个问题,我们加入了一个以 Θ ψ \Theta_{\psi} Θψ为参数的适配层 ψ ( ⋅ ) \psi(\cdot) ψ(),以进一步适配异常检测的图像表征,从而根据适配的图像特征学习图像级残差。此外,我们使用少样本提示的原型特征而非单个样本的特征来学习上下文残差,因为它们有助于捕捉正常模式中更具代表性的特征。
具体来说,假设 f v ( x ) ∈ R d ′ f_v(x)\in\mathbb{R}^{d^{\prime}} fv(x)Rd是输入 x x x在视觉编码器中的类标记嵌入,我们首先计算 P ′ \mathcal{P}^{\prime} P中图像提示的特征图原型:
I p = 1 K ∑ x k ′ ∈ P ′ ψ ( f v ( x k ′ ) ; Θ ψ ) \mathbf{I}_{p}=\frac{1}{K}\sum_{x_{k}^{\prime}\in\mathcal{P}^{\prime}}\psi(f_{v}\left(x_{k}^{\prime});\Theta_{\psi}\right) Ip=K1xkPψ(fv(xk);Θψ)
其中 I p ∈ R d ′ \mathbf{I}_p\in\mathbb{R}^{d^{\prime}} IpRd。设 I x = ψ ( f v ( x ) ; Θ ψ ) \mathbf{I}_x = \psi(f_v(x);\Theta_\psi) Ix=ψ(fv(x);Θψ)为查询图像 x x x的适配特征, x x x的上下文图像级残差特征 F x \mathbf{F}_x Fx是通过对两个特征图进行元素减法得到的:
F x = I x ⊖ I p \mathbf{F}_x=\mathbf{I}_x\ominus\mathbf{I}_p Fx=IxIp
其中 ⊖ 表示元素减法。随后,这些上下文中的残差特征被输入图像级异常分类学习器 η : F x → R \eta:\mathbf{F}_{x}\to\mathbb{R} η:FxR,参数为 Θ η \Theta_{\eta} Θη ,该学习器通过二元分类损失进行优化:
L I R L = 1 N ∑ x ∈ X t r a i n L b ( η ( F x ; Θ η ) , y x ) \mathcal{L}_{IRL}=\frac1N\sum_{x\in X_{train}}\mathcal{L}_b(\eta(\mathbf{F}_x;\Theta_\eta),y_x) LIRL=N1xXtrainLb(η(Fx;Θη),yx)
其中 L b \mathcal{L}_b Lb是二元分类损失。我们的模型默认使用 Focal loss

1.4.4 融合基于文本提示的先验知识

上述两个组件的重点是基于视觉编码器的残差学习。InCTRL 还可以轻松纳入由 CLIP 文本编码器提供的关于正常和异常的文本提示先验知识。这有助于 InCTRL 利用隐藏在 CLIP 预训练的图像文本对齐嵌入空间中的正常和异常语义来处理 GAD。受此启发,InCTRL 利用文本编码器提取文本提示引导的判别特征。由于 WinCLIP 中设计的文本提示显示出显著的检测性能,InCTRL 采用了相同的文本提示模板及其集合策略,包括状态和模板级文本提示。在状态级,使用通用文本描述来区分正常和异常对象,而模板级则提供了一个专门用于异常检测的特定提示列表。
值得注意的是,与 WinCLIP 使用这些文本提示直接计算异常得分不同,InCTRL 利用它们来提取文本提示引导的特征,以补充通过视觉编码器获得的补丁级和图像级残差特征。具体来说,假设 P t n \mathcal{P}_{t}^{n} Ptn是正常类的文本提示集,我们使用文本提示嵌入的原型来提供正常文本提示的代表性嵌入 F n = 1 ∣ P t n ∣ ∑ p i ∈ P t n f t ( p i ) \mathbf{F}_{n} = \frac{1}{|\mathcal{P}_{t}^{n}|}\sum_{p_{i}\in\mathcal{P}_{t}^{n}}f_{t}(p_{i}) Fn=Ptn1piPtnft(pi)其中 p i ∈ R d ′ p_{i}\in\mathcal{R}^{d^{\prime}} piRd; 我们可以通过 F a = 1 ∣ P t a ∣ ∑ p j ∈ P t a f t ( p j ) \mathbf{F}_{a} = \frac{1}{|\mathcal{P}_{t}^{a}|}\sum_{p_{j}\in\mathcal{P}_{t}^{a}}f_{t}(p_{j}) Fa=Pta1pjPtaft(pj)得到异常文本提示集 P t a \mathcal{P}_{t}^{a} Pta的原型嵌入。然后,InCTRL 会根据查询图像 x x x与文本提示的两个原型之间的相似性提取一个面向 AD 的判别特征:
s a ( x ) = exp ⁡ ( F a ⊺ f v ( x ) ) exp ⁡ ( F n ⊺ f v ( x ) ) + exp ⁡ ( F a ⊺ f v ( x ) ) s_{a}(x)=\frac{\exp(\mathbf{F}_a^\intercal f_v(x))}{\exp(\mathbf{F}_n^\intercal f_v(x))+\exp(\mathbf{F}_a^\intercal f_v(x))} sa(x)=exp(Fnfv(x))+exp(Fafv(x))exp(Fafv(x))
其中 [ ⋅ ] ⊺ [\cdot]^{\intercal} []表示转置运算, s a ( x ) s_{a}(x) sa(x)是输入 x x x被归类为异常的概率。

1.4.5 推理和训练

上下文残差学习。在训练过程中,InCTRL 执行整体残差学习,综合补丁级和图像级残差信息,并通过文本提示引导的特征进行增强。查询图像 x x x的整体上下文残差图定义为
M x + = M x ⊕ s i ( x ) ⊕ s a ( x ) \mathrm{M}_x^+=\mathrm{M}_x\oplus s_i(x)\oplus s_a(x) Mx+=Mxsi(x)sa(x)
其中 s i ( x ) = η ( F x ; Θ η ) s_i(x)=\eta(\mathbf{F}_x;\Theta_\eta) si(x)=η(Fx;Θη)是基于图像级残差图 F x \mathbf{F}_{x} Fx的异常得分, ⊕ \oplus 表示元素相加。然后,InCTRL 基于 M x + \mathrm{M}_x^+ Mx+设计了一个整体异常评分函数 ϕ \phi ϕ,参数为 Θ ϕ \Theta_\phi Θϕ,并将最终异常评分定义为
s ( x ) = ϕ ( M x + ; Θ ϕ ) + α s p ( x ) s(x)=\phi(\mathbf{M}_x^+;\Theta_\phi)+\alpha s_p(x) s(x)=ϕ(Mx+;Θϕ)+αsp(x)
其中 ϕ ( M x + ; Θ ϕ ) \phi(\mathbf{M}_x^+;\Theta_\phi) ϕ(Mx+;Θϕ)使用补丁级、图像级和文本提示引导的特征进行整体异常评分,而 s p ( x ) = max ⁡ ( M x ) s_p(x) = \max(\mathbf{M}_x) sp(x)=max(Mx)是图像补丁级基于最大残差分的细粒度异常评分。
s p ( x ) s_p(x) sp(x)被添加到上式中,因为这种补丁级异常得分对于检测局部异常区域至关重要,而基于 ϕ \phi ϕ的整体异常得分往往会忽略这些区域。 α \alpha α是一个超参数,用于调节补丁级残差得分的贡献。最后,我们使用 X t r a i n X_{train} Xtrain优化最终的异常得分:
L h = 1 N ∑ x ∈ X t r a i n L b ( s ( x ) , y x ) \mathcal{L}_{h}=\frac{1}{N}\sum_{x\in X_{train}}\mathcal{L}_{b}(s(x),y_{x}) Lh=N1xXtrainLb(s(x),yx)
因此,完整的 InCTRL 模型是通过最小化整体损失来优化的:
L I n C T R L = L I R L + L h \mathcal{L}_{InCTRL}=\mathcal{L}_{IRL}+\mathcal{L}_h LInCTRL=LIRL+Lh
推理。在推理过程中,对于给定的测试图像 xt 和来自目标数据集的 K 张正常图像提示集 P \mathcal{P} P,它们将通过视觉编码器和适配器层进行前馈,从而得到 M x t \mathbf{M}_{x_t} Mxt s i ( x t ) s_i(x_t) si(xt)。训练过程中使用的文本提示集用于获得 s a ( x t ) s_a(x_t) sa(xt)。最后,我们通过上述公式 s ( x ) s(x) s(x)得出 x t x_t xt的最终异常得分。

二、代码通读

通过介绍各个模块,由浅入深理解源代码

2.1 InCTRL 类

核心代码,介绍了 InCTRL 的主要框架:Multi-layer Patch-level Residual Learning、Image-level Residual Learning 和推理训练

class InCTRL(nn.Module):
    def __init__(
            self,
            args, # 模型配置的参数
            embed_dim: int, # 嵌入维度
            vision_cfg: CLIPVisionCfg, # 视觉编码器的配置
            text_cfg: CLIPTextCfg, # 文本编码器的配置
            quick_gelu: bool = False, # 是否使用快速GRLU激活函数
            cast_dtype: Optional[torch.dtype] = None, # 数据类型,用于混合精度训练
            output_dict: bool = False, # 是否以字典形式输出模型的输出
    ):
        super().__init__()
        self.output_dict = output_dict
        # 构建视觉编码器
        self.visual = _build_vision_tower_Mul(embed_dim, vision_cfg, quick_gelu, cast_dtype)
        
        # 构建视觉编码器
        text = _build_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype)
        self.transformer = text.transformer # 文本编码器的Transformer部分
        self.context_length = text.context_length
        self.vocab_size = text.vocab_size
        self.token_embedding = text.token_embedding
        self.positional_embedding = text.positional_embedding
        self.ln_final = text.ln_final # 激活层
        self.text_projection = text.text_projection
        self.register_buffer('attn_mask', text.attn_mask, persistent=False)

        self.adapter = Adapter(640, 4) # 适配器,用于调整特征维度
        self.diff_head = TransformerBasicHead(225, 1) # 用于计算差异分数的基本Transformer头
        self.diff_head_ref = TransformerBasicHead(640, 1) 

        for p in self.visual.parameters():
            p.requires_grad = False

        for p in text.parameters():
            p.requires_grad = False

    # 对图像进行编码,生成特征表示
    def encode_image(self, image, out_layers: list = [7, 9, 11], normalize: bool = False):
        features = self.visual.forward(image, out_layers)
        return F.normalize(features, dim=-1) if normalize else features # 标准化

    # 对文本进行编码,生成特征表示
    def encode_text(self, text, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()
        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding.to(cast_dtype) # 位置嵌入
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        # 从 eot 嵌入中提取特征(eot_token 是每个序列中的最高数字)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
        return F.normalize(x, dim=-1) if normalize else x

    def forward(self, tokenizer, image: Optional[torch.Tensor] = None, text: Optional[torch.Tensor] = None, normal_list = None):
        # 根据是否提供了 normal_list,处理正常样本的图像
        if normal_list == None: 
            img = image[0].cuda(non_blocking=True)
            normal_image = image[1:]
            normal_image = torch.stack(normal_image) # 将多个图像组合成一个batch
            shot, b, _, _, _ = normal_image.shape
            normal_image = normal_image.reshape(-1, 3, 240, 240).cuda(non_blocking=True)
        else:
            img = image[0].cuda(non_blocking=True)
            normal_image = normal_list
            normal_image = torch.stack(normal_image)
            normal_image = normal_image.unsqueeze(1)
            b = len(img)
            normal_image = normal_image.repeat(1, b, 1, 1, 1)
            shot, _, _, _, _ = normal_image.shape
            normal_image = normal_image.reshape(-1, 3, 240, 240).cuda(non_blocking=True)

        token, Fp_list, Fp = self.encode_image(img, normalize=False)
        token_n, Fp_list_n, Fp_n = self.encode_image(normal_image, normalize=False)

        Fp_list = torch.stack(Fp_list)
        Fp_list_n = torch.stack(Fp_list_n)

        Fp_list = Fp_list[:, :, 1:, :]
        Fp_list_n = Fp_list_n[:, :, 1:, :]

        Fp_list = Fp_list.reshape(b, 3, 225, -1)
        Fp_list_n = Fp_list_n.reshape(b, 3, 225 * shot, -1)

        token_n = token_n.reshape(b, shot, -1)

        token_ad = self.adapter.forward(token)
        token_n = self.adapter.forward(token_n)
        token_n = torch.mean(token_n, dim=1)
        token_ref = token_n - token_ad

        
        text_score = []
        max_diff_score = [] # 保存 当前图像块与正常图像块之间最大的差异程度
        patch_ref_map = []
        for i in range(len(token)):
            Fp = Fp_list[i, :, :, :]
            Fp_n = Fp_list_n[i, :, :, :]
            # 实现 多层补丁级残差学习 模块
            Fp_map = list()
            for j in range(len(Fp)):
                tmp_x = Fp[j, :, :]
                tmp_n = Fp_n[j, :, :]
                am_fp = list()
                for k in range(len(tmp_x)):
                    tmp = tmp_x[k]
                    tmp = tmp.unsqueeze(0)
                    tmp_n = tmp_n / tmp_n.norm(dim=-1, keepdim=True)
                    tmp = tmp / tmp.norm(dim=-1, keepdim=True)
                    s = (0.5 * (1 - (tmp @ tmp_n.T))).min(dim=1).values # 计算相似度
                    am_fp.append(s)
                am_fp = torch.stack(am_fp)
                Fp_map.append(am_fp)
            Fp_map = torch.stack(Fp_map)
            Fp_map = torch.mean(Fp_map.squeeze(2), dim=0)
            patch_ref_map.append(Fp_map)
            score = Fp_map.max(dim=0).values
            max_diff_score.append(score)

            # 计算图像特征与文本特征之间的相似度
            image_feature = token[i]
            image_feature = image_feature.unsqueeze(0)
            image_feature = image_feature / image_feature.norm(dim=-1, keepdim=True)

            obj_type = text[i] # 获取与当前图像对应的文本描述
            normal_texts, anomaly_texts = get_texts(obj_type.replace('_', " "))
            pos_features = tokenizer(normal_texts).cuda() # 使用tokenizer将正常和异常的文本转换为模型可处理的格式,
            neg_features = tokenizer(anomaly_texts).cuda() # 并使用CUDA进行加速
            pos_features = self.encode_text(pos_features) # 将文本编码为特征表示
            neg_features = self.encode_text(neg_features)
            pos_features = pos_features / pos_features.norm(dim=-1, keepdim=True) # 归一化
            neg_features = neg_features / neg_features.norm(dim=-1, keepdim=True)
            pos_features = torch.mean(pos_features, dim=0, keepdim=True)
            neg_features = torch.mean(neg_features, dim=0, keepdim=True)
            pos_features = pos_features / pos_features.norm(dim=-1, keepdim=True)
            neg_features = neg_features / neg_features.norm(dim=-1, keepdim=True)
            text_features = torch.cat([pos_features, neg_features], dim=0) #将正常和异常的特征拼接在一起
            score = (100 * image_feature @ text_features.T).softmax(dim=-1) # 图像特征与正负样本文本特征的匹配度
            tmp = score[0, 1]
            text_score.append(tmp)

        # 推理和训练 部分
        text_score = torch.stack(text_score).unsqueeze(1) # Sa(x)
        # 之前计算的参考图像特征token_ref进行前向传播,得到图像的异常分数 Si(x)
        img_ref_score = self.diff_head_ref.forward(token_ref)
        patch_ref_map = torch.stack(patch_ref_map) # Mx
        # 将文本分数、参考图像分数和图像块参考映射相加,得到一个综合的特征映射
        holistic_map = text_score + img_ref_score + patch_ref_map # Mx+
        hl_score = self.diff_head.forward(holistic_map) # 整体的异常分数

        hl_score = hl_score.squeeze(1)
        fg_score = torch.stack(max_diff_score)
        final_score = (hl_score + fg_score) / 2 # 将整体分数和最大差异分数取平均

        img_ref_score = img_ref_score.squeeze(1)

        return final_score, img_ref_score # 得到 S(x) Si(x)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值