Abstract
随着光学传感器质量的提高,需要处理大规模图像。特别是,设备捕获超高清 (UHD) 图像和视频的能力对图像处理管道提出了新的要求。在本文中,我们考虑低光图像增强(LLIE)的任务,并引入由 4K 和 8K 分辨率图像组成的大规模数据库。我们进行系统的基准测试研究并提供当前 LLIE 算法的比较。作为第二个贡献,我们引入了 LLFormer,一种基于 Transformer 的低光增强方法。 LLFormer的核心组件是基于轴的多头自注意力和跨层注意力融合块,显着降低了线性复杂度。对新数据集和现有公共数据集的大量实验表明,LLFormer 的性能优于最先进的方法。我们还表明,采用在我们的基准上训练的现有 LLIE 方法作为预处理步骤可以显着提高下游任务的性能,例如低光条件下的人脸检测。源代码和预训练模型可在 https://github.com/TaoWangzj/LLFormer 获取。
Introduction
在弱光条件下拍摄的图像通常会出现明显的退化,例如可见度差、对比度低和噪点水平高。为了减轻这些影响,人们提出了许多低光图像增强(LLIE)方法,将给定的低光图像转换为具有适当亮度的高质量图像。传统的LLIE方法主要基于图像先验或其他任务的物理模型,例如基于直方图均衡的方法(Kim 1997;Stark 2000)、基于视网膜的方法(Kimmel et al. 2003;Wang et al. 2014)和去雾基于方法(Dong et al. 2011;Zhang et al. 2012)。最近,许多基于学习的 LLIE 方法被引入,利用大规模合成数据集,在性能和速度方面取得了显着的改进(Wei et al. 2018;Guo et al. 2020;Lim and Kim 2020;Jiang et al. al. 2021;Li、Guo 和 Loy,2021;Liu 等人,2021b)。
大多数现有数据集,例如 LOL (Wei et al. 2018) 和 SID (Chen et al. 2018),由较低分辨率的图像(1K 或更低)组成。因此,在这些数据集上训练的 LLIE 方法自然受限于低分辨率图像。现代移动设备上的传感器能够捕获 4K 或 8K 分辨率的图像,因此需要专门用于处理超高清 (UHD) 图像的算法。现有的 LLIE 方法很难同时协调超高清图像的推理效率和视觉增强。在本文中,我们重点关注超高清低光图像增强(UHD-LLIE)的任务。我们首先构建一个包含低光条件下超高清图像(UHD-LOL)的大规模基准数据集,以探索和评估图像增强算法。 UHD-LOL 包括两个子集:UHD-LOL4K 和 UHD-LOL8K,分别包含 4K 和 8K 分辨率图像。 UHD-LOL4K 子集包含 8, 099 个图像对,其中 5, 999 个用于训练,2, 100 个用于测试。 UHD-LOL8K 的子集包括 2, 966 个图像对,其中 2, 029 个用于训练,937 个用于测试。 4K 和 8K 低光图像示例如图 1 所示。
使用该数据集,我们进行了广泛的基准测试研究,以比较现有的 LLIE 方法并强调 UHD 设置中的一些缺点。我们针对 UHD-LLIE 任务提出了一种基于 Transformer 的新颖方法,名为 Low-Light Transformerbased Network (LLFormer)。 LLFormer 由两个基本单元组成,一个高效的 Axisbased Transformer Block 和一个 Cross-layer Attention Fusion Block。在 Axis-based Transformer 模块中,基于轴的自注意力单元在通道维度上的特征的高度和宽度轴上执行自注意力机制,以较低的计算复杂度捕获非局部自相似性和远程依赖性。此外,在基于轴的自注意力之后,我们设计了一种新颖的双门控前馈网络,它采用双门控机制来关注有用的特征。跨层注意力融合块学习不同层中特征的注意力权重,并自适应地将特征与学习到的权重融合以改进特征表示。 LLFormer采用分层结构,极大缓解了UHD-LLIE任务的计算瓶颈。
总而言之,我们的工作贡献如下:
1. 我们构建了4K和8K超高清图像的基准数据集UHD-LOL,以探索和评估图像增强算法。据我们所知,这是文献中第一个大规模超高清低光图像增强数据集。
2. 基于UHD-LOL,我们对现有的LLIE算法进行基准测试,以展示这些方法的性能和局限性,提供新的见解。
3. 我们为 UHD-LLIE 任务提出了一种新颖的变压器模型 LLFormer。在定量和定性方面,LLFormer 在公共 LOL 和 MIT-Adobe FiveK 数据集以及我们的 UHD-LOL 基准上都实现了最先进的性能。
Related Work
Low-light Image Datasets.
随着数据驱动方法的进步(Zhang et al. 2022, 2021b),许多用于低光图像增强的数据集被提出。在(Vonikakis 等人,2013 年)中,Vonikakis 等人。引入了 LLIE 数据集,其中包含从 15 个不同场景收集的 225 张图像。在每个场景中,他们拍摄了 15 张图像,其中 9 张是不同强度的均匀照明条件下的图像,6 张是非均匀照明条件下的图像。 Lore 等人通过应用随机伽玛校正和高斯噪声。 (Lore、Akintayo 和 Sarkar 2017)从 169 张图像中合成了 422、500 个低光图像块。 (Shen 等人,2017)创建了一个包含 10, 000 个图像对的 LLIE 数据集,其中 8, 000 个用于训练,2, 000 个用于测试。 (Chen 等人,2018)构建了“黑暗中看到”(SID)数据集。它包含 5, 094 张短曝光低光原始图像及其相应的长曝光图像。 (Cai、Gu 和Zhang 2018)使用多重曝光图像融合(MEF)或高动态范围(HDR)算法从 589 个图像序列合成了 SICE 数据集。 (Wei et al. 2018) 创建了低光 (LOL) 数据集,其中包含 485 个用于训练的图像对和 15 个用于测试的图像对。 Liu 等人基于 LOL 数据集。 (Liu et al. 2021a) 创建了 VE-LOL-L 用于训练和评估 LLIE 方法,其中包括 2、100 个用于训练的图像和 400 个用于评估的图像。 MIT-Adobe FiveK(Bychkovsky 等人,2011 年)由 5, 000 张在不同照明条件下拍摄的各种室内和室外场景的图像组成。在本文中,我们介绍了 4K 和 8K 图像的超高清低光图像增强基准,其大小足以用于训练深度模型和比较现有方法。
Low-light Image Enhancement Methods.
LLIE 旨在从低光条件下拍摄的曝光不足的图像中恢复图像(Li et al. 2022)。在过去的几十年里,针对 LLIE 任务提出了许多方法,这些方法大致可以分为两类:基于非学习的方法和数据驱动的方法。传统方法主要包括基于直方图的方法(Lee, Lee, and Kim 2012; Xu et al. 2013; Celik 2014)、基于Retinex的方法(Kimmel et al. 2003; Wang et al. 2014)和基于去雾的方法(Zhang 等人,2012 年;Li 等人,2015 年)。尽管传统方法获得了合理的结果,但这些方法的增强图像通常会出现颜色失真或过度增强等伪影。数据驱动方法(Wei et al. 2018;Zhang et al. 2020;Lim and Kim 2020;Ni et al. 2020;Zeng et al. 2020;Wang et al. 2022a, 2021;Yang et al. 2022)成功应用于LLIE任务。例如,(Wei et al. 2018)中的 RetinexNet 将 Retinex 理论和深度 CNN 结合在统一的端到端学习框架中。 DSLR(Lim 和 Kim 2020)应用拉普拉斯金字塔方案来增强编码器-解码器架构的多个流中的全局和局部细节。最近,基于 Transformer 的数据驱动方法已应用于低级任务:Uformer(Wang et al. 2022b)使用修改后的 Swin Transformer 块(Liu et al. 2021c)构建 U 形网络,在图像方面表现出良好的性能恢复。 Restormer(Zamir 等人,2022)引入了对 Transformer 块的修改,以改进图像恢复的特征聚合。虽然 Transformer 在许多任务中表现良好,但其弱光图像增强的潜力仍未被开发。在这项工作中,我们专注于设计用于超高清低光图像增强的变压器。
Benchmark and Methodology
Benchmark Dataset
我们创建了一个名为 UHDLOL 的新的大规模 UHD-LLIE 数据集,以对现有 LLIE 方法的性能进行基准测试并探索 UHD-LLIE 问题。 UHD-LOL分别由3840×2160分辨率的4K图像和7680×4320分辨率的8K图像组成。为了构建这个图像对数据集,我们使用公共数据中的普通光 4K 和 8K 图像 (Zhang et al. 2021a)。这些超高清图像是从网络上爬取并由各种设备捕获的。图像包含室内和室外场景,包括建筑物、街道、人物、动物和自然景观。我们合成了相应的低光图像(Wei et al. 2018),它同时考虑了低光退化过程和自然图像统计。具体来说,我们首先生成三个随机变量 X、Y、Z,均匀分布在 (0, 1) 中。我们使用这些变量来生成 Adobe Lightroom 软件提供的参数。参数包括曝光 (−5+5X2)、高光 (50 min{Y, 0.5} + 75)、阴影 (−100 min{Z, 0.5})、鲜艳度 (−75 + 75X2) 和白色 (16(5) - 5X2))。合成的低光和正常光图像组成了我们的 UHD-LOL,它由两个子集组成:UHD-LOL4K 和 UHD-LOL8K。 UHD-LOL4K 子集包含 8, 099 对 4K 低光/正常光图像。其中,5, 999 对图像用于训练,2, 100 对图像用于测试。 UHD-LOL8K 子集包括 2, 966 对 8K 低光/正常光图像,分为 2, 029 对用于训练,937 对用于测试。示例图像如图 1 所示。
LLFormer Architecture
如图2所示,LLFormer的整体架构是一个分层的编码器-解码器结构。给定低光图像 I ∈ RH×W×3,LLFormer 首先采用 3 × 3 卷积作为投影层来提取浅层特征 F0 ∈ RH×W×C。接下来,F0 被输入到三个连续的 Transformer 块中以提取更深层次的特征。更具体地说,从 Transformer 块输出的中间特征表示为 F1、F2、F3 ∈ RH×W×C。这些特征 F1、F2、F3 通过所提出的跨层注意融合块进行聚合并转换为增强的图像特征 F4。其次,编码器中的四个阶段用于 F4 上的深度特征提取。具体来说,每一级包含一个下采样层和多个变压器块。从顶部到底部,变压器块的数量不断增加。我们使用像素取消洗牌操作(Shi et al. 2016)来缩小空间尺寸并将通道数加倍。因此,编码器第i级的特征可以表示为 Xi ∈ R H 2i ×W 2i ×2iC ,对应于四个级 i = 0, 1, 2, 3 。随后,低分辨率潜在特征 X3 通过包含三个阶段的解码器,并以 X3 作为输入并逐步恢复高分辨率表示。每个阶段都由一个上采样层和多个变压器块组成。解码器第 i 阶段的特征表示为 X i ∈ R H 2i ×H 2i ×2i+1C,i = 0, 1, 2。我们应用像素洗牌操作(Shi et al. 2016)进行上采样。为了减轻编码器中的信息损失并在解码器中很好地恢复特征,我们使用带有1×1卷积的加权跳跃连接来进行编码器和解码器之间的特征融合,这可以灵活地调整特征的贡献编码器和解码器。第三,在解码器之后,深度特征F依次经过三个变换器块和跨层注意融合块以生成用于图像重建的增强特征。最后,LLFormer 对增强的特征应用 3 × 3 卷积以产生增强的图像 ^I。我们使用平滑的 L1 损失来优化 LLFormer (Girshick 2015)。
Figure 2: LLFormer architecture. LLFormer 的核心设计包括基于轴的 Transformer 模块和跨层注意力融合模块。前者中,基于轴的多头自注意力在跨通道维度的高度和宽度轴上依次进行自注意力以降低计算复杂度,而双门控前馈网络采用门控机制来更多地关注有用的功能。跨层注意力融合块在融合不同层的特征时学习它们的注意力权重。
Axis-based Transformer Block
与 CNN 相比,Transformer 在建模非局部自相似性和远程依赖性方面具有优势。然而,正如(Vaswani et al. 2017;Liu et al. 2021c)中所讨论的,标准 Transformer 的计算成本与输入特征图的空间大小 (H ×W) 成二次方。此外,将变压器应用于高分辨率图像尤其是超高清图像通常是不可行的。为了解决这个问题,我们在变压器块中提出了一种基于轴的多头自注意力(A-MSA)机制。 A-MSA的计算复杂度与空间大小呈线性关系,大大降低了计算复杂度。此外,我们在普通变压器前馈网络中引入了双门控机制,并提出了双门控前馈网络(DGFN)来捕获特征中更重要的信息。我们将 A-MSA 和 DGFN 与普通变压器单元集成,以构建基于轴的变压器块 (ATB)。如图2所示,ATB包含A-MSA、DGFN和两个归一化层。 ATB的计算公式为:
Axis-based Multi-head Self-Attention.
标准自注意力的计算复杂度与输入分辨率成二次方,即 H × W 特征图的 O(W2H2) 。我们提出 A-MSA,而不是全局计算自注意力,如图 2 所示,顺序计算跨通道维度的高度和宽度轴上的自注意力。由于此操作,我们的 A-MSA 的复杂性降低为线性。此外,为了减轻 Transformer 在捕获局部依赖性方面的限制,我们采用深度卷积来帮助 A-MSA 在计算特征注意力图之前关注局部上下文(Zamir 等人,2022 年;Wang 等人,2022b)。由于高度轴多头自注意力机制和宽度轴多头自注意力机制相似,因此为了便于说明,我们仅介绍高度轴多头自注意力的细节。
Dual Gated Feed-Forward Network
之前的研究表明,前馈网络(FFN)在捕获本地上下文方面存在局限性(Vaswani 等人,2017 年;Dosovitskiy 等人,2021 年)。为了有效地进行特征转换,我们在 FFN 中引入了双门机制和局部信息增强,并提出了一种新颖的双门前馈网络(DGFN)。如图3(c)所示,对于双门控机制,我们首先在两个并行路径中应用双GELU和元素乘积来过滤信息量较少的特征,然后通过元素求和融合来自两个路径的有用信息。此外,我们在每个路径中应用 1 × 1 卷积(W1×1)和 3 × 3 深度卷积(W3×3)来丰富局部信息。给定 Y ∈ RH×W×C 作为输入,完整的 DGFN 公式为:
Cross-layer Attention Fusion Block
最近基于 Transformer 的方法采用特征连接或跳跃连接来组合来自不同层的特征(Zamir 等人,2022 年;Wang 等人,2022b)。然而,这些操作没有充分利用不同层之间的依赖关系,限制了表示能力。为了解决这个问题,我们提出了一种新颖的跨层注意融合块(CAFB),它自适应地融合层次特征与不同层之间的可学习相关性。 CAFB 背后的直觉是,不同层的激活是对特定类的响应,并且可以使用自注意力机制自适应地学习特征相关性。
其中 ^Fout 是关注网络信息层的输出特征。在实践中,我们将所提出的 CAFB 放置在网络中头部和尾部的对称位置,以便 CAFB 有助于捕获特征提取和图像重建过程中分层层之间的长距离依赖关系。
Experiments and Analysis
Implementation Details
LLFormer 在 128 × 128 个 patch 上进行训练,批量大小为 12。对于数据增强,我们采用水平和垂直翻转。我们使用 Adam 优化器,初始学习率为 10−4,并使用余弦退火将其降低至 10−6。 LLFormer中从阶段1到阶段4的编码器块的数量为{2,4,8,16},A-MSA中的注意力头的数量为{1,2,4,8}。第1级到第3级解码器对应的编号为{2,4,8}和{1,2,4}。为了进行基准测试,我们比较了 16 种代表性的 LLIE 方法,包括 7 种传统的非学习方法(BIMEF (Ying, Li, and Gau 2017)、FEA (Dong et al. 2011)、LIME (Guo, Li, and Ling 2016)、MF (Fu et al. 2016a)、NPE (Wang et al. 2013)、SRIE (Fu et al. 2016b)、MSRCR (Jobson, Rahman, and Woodell 1997))、三种基于 CNN 的监督方法 (RetinexNet (Wei et al.) . 2018)、DSLR(Lim 和 Kim 2020)、KinD(Zhang、Zhang 和Guo 2019))、两种基于 CNN 的无监督方法(ELGAN(Jiang 等人,2021)、RUAS(Liu 等人,2021b))、两种基于零样本学习的方法(Z_DCE(Guo 等人,2020)、Z_DCE++(Li、Guo 和 Loy 2021))和两种基于监督变压器的方法(Uformer(Wang 等人,2022b)、Restormer(Zamir 等人)等2022))。对于每种方法,我们都使用公开可用的代码,并对每种基于学习的方法进行 300 轮训练。对于ELGAN,我们直接使用其预训练的模型进行测试。使用 PSNR、SSIM、LPIPS 和 MAE 指标评估性能。
Benchmarking Study for UHD-LLIE
UHD-LOL4K Subset.
我们在 UHDLOL4K 子集上测试了 16 种不同的最先进的 LLIE 方法和我们提出的 LLFormer。定量结果如表1所示。根据表1,我们可以发现传统的LLIE算法(BIMEF、FEA、LIME、MF、NPE、SRIE、MSRCR)通常在UHD-LOL4K上效果不佳。其中,一些方法的定量得分(PSNR、SSIM、LPIPS、MAE)甚至比无监督学习方法(RUAS、ELGAN)还要差。基于 CNN 的监督学习方法(参见 RetiunexNet、DSLR 和 KID)的结果优于基于无监督学习和基于零样本学习的方法,这是预期的。在基于 CNN 的方法中,DSLR 在 PSNR、SSIM、LPIPS 和 MAE 方面获得了最佳性能。与基于CNN的监督学习方法相比,基于Transformer的监督学习方法(Uformer、Restormer和LLFormer)的性能得到了很大的提高。其中,所提出的 LLFormer 获得了最佳性能,与 Restormer 相比,PSNR 提高了 0.42 dB。视觉比较如图5所示。LLFormer恢复的图像色彩鲜艳,更接近真实情况。
UHD-LOL8K Subset.
我们还通过将每个 8K 图像划分为 4K 分辨率的 4 个块,对 UHD-LOL8K 子集进行基准测试实验。表1的最后四列显示了评估结果。深度学习方法 RetinexNet、DSLR、Uformer、Restormer 和 LLFormer 在像素方面和感知指标上都取得了更好的性能。基于 Transformer 的方法在所有评估指标上都名列前茅,LLFormer 的性能优于其他方法。如图 5 所示,LLFormer 产生了具有更多细节的视觉上令人愉悦的结果。
Improving Downstream Tasks.
为了验证 LLIE 是否有利于下游任务,我们从 DARK FACE 数据集(Yang et al. 2020)中随机选择 300 张图像,并使用基准研究中的前三种方法对这些图像进行预处理。然后,我们使用 RetinaFace 检测人脸(Deng et al. 2020)。使用预处理步骤时,Uformer、Restormer 和 LLFormer 的平均精度 (AP) 值分别提高了 67.06%、68.11% 和 71.2%。视觉结果如图6所示。预训练的LLIE模型不仅生成具有足够色彩平衡的图像,还有助于提高下游任务的性能。
Comparison results on Public Datasets
我们在 LOL(Wei 等人,2018 年)和 MIT-Adobe FiveK(Bychkovsky 等人,2011 年)数据集上对 LLFormer 进行基准测试,将其与专门为 LLIE 设计的 14 种方法和两种基于 Transformer 的方法进行比较。我们使用已发布的代码分别在这些数据集上重新训练 Uformer 和 Restormer。结果如表 2 所示。LLFormer 在 LOL 数据集上取得了显着更高的性能,获得了比 Restormer 更高的 PSNR、SSIM 和 MAE 分数。在LPIPS方面,LLFormer排名第二。在 MIT-Adobe FiveK 数据集上,基于 Transformer 的方法排名靠前,LLFormer 在所有指标上都取得了最佳结果。在最好的三种基于 Transformer 的方法中,Uformer、Restormer 和 LLFormer 的开销(参数和乘法累加运算)分别为 38.82M/76.67G、26.10M/140.99G 和 24.52M/22.52G(在 256 × 256 图像上测量) ), 分别。这表明所提出的 LLFormer 通过有效利用资源实现了最佳性能。这是由于 LLFormer 的设计,基于轴的多头自注意力和分层结构有助于降低计算复杂度。视觉比较如图 7 所示。LLFormer 生成具有足够饱和度以及颜色和纹理保真度的图像
Ablation Studies
我们通过测量以下因素的贡献来进行消融研究:(1)基于轴的多头自注意力; (2) 双门控前馈网络; (3)加权跳跃连接; (4)跨层注意力融合块; (5)网络的宽度和深度。在 UHD-LOL4K 子集上进行实验,并在大小为 128 × 128 的图像块上训练模型 100 个时期。
A. Axis-based Transformer Block.
我们测量了所提出的基于轴的多头自注意力和双门控前馈网络(FFN)的影响,请参见表 3。与使用 Resblock 的基本模型(Lim 等人,2017 年)相比,我们的 A-MSA (高度或宽度)和 DGFN 对改进做出了显着贡献。当使用深度卷积来增强自注意力(比较(d)和(h))或前馈网络(比较(f)和(h))中的局部性时,PSNR方面的改进分别是0.89、0.75 , 分别。通过应用双门控机制,PSNR和SSIM分别提高了3.42和0.0081(参见(g)(h))。将双门控机制与局部性结合使用可产生最佳结果。相反,将 A-MSA 与传统 FFN 相结合(Vaswani 等人,2017 年)会降低性能(表 3 (e))。这表明设计合适的 FFN 对于变压器块至关重要。
B. Skip Connection and Fusion Block.
为了验证加权连接和跨层注意融合块,我们通过逐步删除相应的组件来进行消融研究:(1)跳过,(2)1×1卷积,(3)带有1×1卷积的跳过,(4)头部 CAFB,(5) 尾部 CAFB,(6) 所有 CAFB。表 4 显示了 PSNR 和 SSIM 方面的结果,这表明每个组件都有助于改善结果。当包含 CAFB 和加权跳跃连接时,模型显着改进。当应用 1 × 1 卷积时,我们观察到了微小的增益。
C. Wider vs. Deeper.
为了了解网络中宽度和深度的影响,我们进行了消融实验,以逐渐增加 LLFormer 的宽度(通道)和深度(编码器级数)。表 5 显示了开销、性能和速度方面的结果。结果表明,与更宽或更深的同类相比,LLFormer 在性能和复杂性之间实现了最佳权衡(36.20/0.9867、22.52G、24.52M、0.063s)。
Conclusion
在本文中,我们构建了第一个大规模低光超高清图像增强基准数据集,该数据集由 UHDLOL4K 和 UHD-LOL8K 子集组成。基于该数据集,我们对UHD-LLIE进行了全面的实验。据我们所知,这是专门解决 UHD-LLIE 任务的第一次尝试。我们为 UHD-LLIE 提出了第一个基于变压器的基线网络,称为 LLFormer。大量实验表明 LLFormer 的性能明显优于其他最先进的方法。 UHD-LOL 数据集与 LLFormer 一起作为 LLIE 和 UHD-LLIE 任务的基准将使社区受益。
代码解读
Cross-layer Attention Fusion Block
#### Cross-layer Attention Fusion Block
class LAM_Module_v2(nn.Module):
""" Layer attention module"""
def __init__(self, in_dim,bias=True):
super(LAM_Module_v2, self).__init__()
self.chanel_in = in_dim
self.temperature = nn.Parameter(torch.ones(1))
self.qkv = nn.Conv2d( self.chanel_in , self.chanel_in *3, kernel_size=1, bias=bias)
self.qkv_dwconv = nn.Conv2d(self.chanel_in*3, self.chanel_in*3, kernel_size=3, stride=1, padding=1, groups=self.chanel_in*3, bias=bias)
self.project_out = nn.Conv2d(self.chanel_in, self.chanel_in, kernel_size=1, bias=bias)
def forward(self,x):
"""
inputs :
x : input feature maps( B X N X C X H X W)
returns :
out : attention value + input feature
attention: B X N X N
"""
m_batchsize, N, C, height, width = x.size()
x_input = x.view(m_batchsize,N*C, height, width)
qkv = self.qkv_dwconv(self.qkv(x_input))
q, k, v = qkv.chunk(3, dim=1)
q = q.view(m_batchsize, N, -1)
k = k.view(m_batchsize, N, -1)
v = v.view(m_batchsize, N, -1)
q = torch.nn.functional.normalize(q, dim=-1)
k = torch.nn.functional.normalize(k, dim=-1)
attn = (q @ k.transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
out_1 = (attn @ v)
out_1 = out_1.view(m_batchsize, -1, height, width)
out_1 = self.project_out(out_1)
out_1 = out_1.view(m_batchsize, N, C, height, width)
out = out_1+x
out = out.view(m_batchsize, -1, height, width)
return out
输入x多了维度N,是通过连续三个transformer block分别得到,需要注意的是q维度改变,由B N*C H W改为B N C*H*W,所以attn为B N N,表示跨层注意fusion图,跳连,1x1的卷积project_out。(代码是先1x1卷积再加残差,和原文示意图顺序相反,代码应该有误)
Axis-based Multi-head Self-Attention
class NextAttentionImplZ(nn.Module):
def __init__(self, num_dims, num_heads, bias) -> None:
super().__init__()
self.num_dims = num_dims
self.num_heads = num_heads
self.q1 = nn.Conv2d(num_dims, num_dims * 3, kernel_size=1, bias=bias)
self.q2 = nn.Conv2d(num_dims * 3, num_dims * 3, kernel_size=3, padding=1, groups=num_dims * 3, bias=bias)
self.q3 = nn.Conv2d(num_dims * 3, num_dims * 3, kernel_size=3, padding=1, groups=num_dims * 3, bias=bias)
self.fac = nn.Parameter(torch.ones(1))
self.fin = nn.Conv2d(num_dims, num_dims, kernel_size=1, bias=bias)
return
def forward(self, x):
# x: [n, c, h, w]
n, c, h, w = x.size()
n_heads, dim_head = self.num_heads, c // self.num_heads
reshape = lambda x: einops.rearrange(x, "n (nh dh) h w -> (n nh h) w dh", nh=n_heads, dh=dim_head)
qkv = self.q3(self.q2(self.q1(x)))
q, k, v = map(reshape, qkv.chunk(3, dim=1))
q = F.normalize(q, dim=-1)
k = F.normalize(k, dim=-1)
# fac = dim_head ** -0.5
res = k.transpose(-2, -1)
res = torch.matmul(q, res) * self.fac
res = torch.softmax(res, dim=-1)
res = torch.matmul(res, v)
res = einops.rearrange(res, "(n nh h) w dh -> n (nh dh) h w", nh=n_heads, dh=dim_head, n=n, h=h)
res = self.fin(res)
return res
### Axis-based Multi-head Self-Attention (row and col attention)
class NextAttentionZ(nn.Module):
def __init__(self, num_dims, num_heads=1, bias=True) -> None:
super().__init__()
assert num_dims % num_heads == 0
self.num_dims = num_dims
self.num_heads = num_heads
self.row_att = NextAttentionImplZ(num_dims, num_heads, bias)
self.col_att = NextAttentionImplZ(num_dims, num_heads, bias)
return
def forward(self, x: torch.Tensor):
assert len(x.size()) == 4
x = self.row_att(x)
x = x.transpose(-2, -1)
x = self.col_att(x)
x = x.transpose(-2, -1)
return x
q、k、v通过x进行1x1和3x3卷积得到,经过数组重拍,使得宽度成为通道轴,这样乘积的时候排除宽度方向影响,q乘以k的转置res为高度方向注意力,再进行col_att前,进行x转置,则又做了次宽度方向注意力计算。复杂度通过两次单方向求取注意力而降低。(建议参考原文示意图理解)
Dual Gated Feed-Forward Networ
class FeedForward(nn.Module):
def __init__(self, dim, ffn_expansion_factor, bias):
super(FeedForward, self).__init__()
hidden_features = int(dim*ffn_expansion_factor)
self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
def forward(self, x):
x = self.project_in(x)
x1, x2 = self.dwconv(x).chunk(2, dim=1)
x = F.gelu(x2)*x1 + F.gelu(x1)*x2
x = self.project_out(x)
return x
两个并行路径中应用双GELU和元素乘积来过滤信息量较少的特征,然后通过元素求和融合来自两个路径的有用信息