TransMEF: A Transformer-Based Multi-Exposure Image Fusion FrameworkUsing Self-Supervised Multi-Task

Abstract

在本文中,我们提出了 TransMEF,一种基于 Transformer 的多重曝光图像融合框架,使用自监督多任务学习。 该框架基于编码器-解码器网络,可以在大型自然图像数据集上进行训练,并且不需要地面实况融合图像。 根据多曝光图像的特点,我们设计了三个自监督重建任务,并利用多任务学习同时进行这些任务; 通过这个过程,网络可以学习多重曝光图像的特征并提取更通用的特征。 此外,为了弥补基于 CNN 的架构中建立远程依赖关系的缺陷,我们设计了一种将 CNN 模块与 Transformer 模块相结合的编码器。 这种组合使网络能够关注本地和全球信息。 我们在最新发布的多重曝光图像融合基准数据集上评估了我们的方法,并将其与 11 种具有竞争力的传统和基于深度学习的方法进行了比较,我们的方法在主观和客观评估中都取得了最佳性能。 代码可在 https://github.com/miccaiif/TransMEF 获取。

1 Introduction

由于普通成像传感器的低动态范围(LDR),单张图像经常出现曝光不足或曝光过度的情况,无法描绘自然场景中亮度水平的高动态范围(HDR)。 多重曝光图像融合(MEF)技术通过将不同曝光的 LDR 图像融合为单个 HDR 图像,提供了一种经济有效的解决方案,因此广泛应用于移动设备的 HDR 成像(Reinhard et al. 2010;Hasinoff et al. 2016; 沉等人,2011)。

MEF的研究历史悠久,并提出了一系列传统方法(Li, Manjunath, and Mitra 1995; Liu and Wang 2015; Lee, Park, and Cho 2018; Ma and Wang 2015; Ma et al. 2017b ,一个)。 然而,它们的性能受到限制,因为弱手工表示的泛化性较低,并且对于不同的输入条件不稳健(Zhang 2021;Zhang et al. 2020b;Xu et al. 2020a)。

近年来,基于深度学习的算法逐渐成为MEF领域的主流。 在这些方法中,将具有不同曝光度的两幅源图像直接输入到融合网络中,并从网络的输出中获得融合图像。 融合网络可以使用地面实况融合图像以常见的监督方式进行训练(Zhang et al. 2020b;Wang et al. 2018;Li and Zhuang 2018),或者通过鼓励融合图像保留融合图像的不同方面来以无监督方式进行训练。 源图像中的重要信息(Xu et al. 2020a;Ram Prabhakar、Sai Srikar 和 Venkatesh Babu 2017;Xu et al. 2020b;Zhang et al. 2020a;Ma et al. 2019)。 然而,有监督和无监督 MEF 方法都需要大量的多重曝光数据进行训练。 尽管许多研究人员(Ram Prabhakar、Sai Srikar 和 Venkatesh Babu 2017;Cai、Gu 和 Zhang 2018;Zeng 等人 2014)收集了各种多重曝光数据集,但其数量无法与 ImageNet 等大型自然图像数据集相比。 (Deng 等人,2009)或 MS-COCO(Lin 等人,2014)。 缺乏大量训练数据通常会导致过度拟合或繁琐的参数优化。 此外,有监督的 MEF 方法对地面实况的需求也很高,但在该领域并不常见(Zhang 2021)。 一些研究人员合成地面实况图像(Wang et al. 2018)或使用其他方法的融合结果作为地面实况进行训练(Yin et al. 2020;Chen and Chuang 2020;Xu、Ma和Zhang 2020)。 然而,这些真实图像并不真实,使用它们会导致性能较差。

此外,所有现有的基于深度学习的 MEF 方法都利用卷积神经网络(CNN)进行特征提取,但由于 CNN 的感受野较小,这是其固有的局限性,因此很难对长程依赖性进行建模。 在图像融合中,融合图像的质量与感受野内的像素以及整个图像的像素强度和纹理有关。 因此,需要对全局和局部依赖关系进行建模。

为了解决上述问题,我们提出了 TransMEF,一种基于 Transformer 的多重曝光图像融合框架,使用自监督多任务学习。 该框架基于编码器-解码器网络,并使用自监督图像重建在大型自然图像数据集上进行训练,以避免使用多重曝光图像进行训练。 在融合阶段,我们首先应用训练有素的编码器从两个源图像中提取特征图,然后应用训练有素的解码器从融合的特征图生成融合图像。 我们还根据多重曝光图像的特征设计了三个自监督重建任务来训练网络,使我们的网络更有效地学习这些特征。

此外,我们设计了一个包含 CNN 模块和 Transformer 模块的编码器,以便编码器可以利用局部和全局信息。 大量实验证明了自监督重建任务以及基于 Transformer 的编码器的有效性,并表明我们的方法在主观和客观评估方面均优于最先进的 MEF 方法。

本文的主要贡献总结如下:

• 我们根据多重曝光图像的特点提出了三个自监督重建任务,并使用多任务学习训练编码器-解码器网络,使我们的网络不仅能够在大型自然图像数据集上进行训练,而且还能够学习 多重曝光图像的特征

• 为了弥补基于CNN的架构中建立远程依赖关系的缺陷,我们设计了一种将CNN模块与变压器模块相结合的编码器,使网络在特征提取过程中能够利用局部和全局信息。

• 为了与其他融合方法进行公平、全面的比较,我们使用最新发布的多重曝光图像融合基准数据集(Zhang 2021)作为测试数据集。 我们从四个角度选择了 12 个客观评估指标,并将我们的方法与 MEF 领域 11 个有竞争力的传统和基于深度学习的方法进行了比较。 我们的方法在主观和客观评估中都达到了最佳性能。

2 Related Work

2.1 Traditional MEF Algorithms

传统的MEF方法可以进一步分为基于空间域的融合方法(Liu andWang 2015;Lee, Park, and Cho 2018;Ma and Wang 2015;Ma et al. 2017b,a)和基于变换域的融合方法(Li, Manjunath 和 Mitra 1995;Burt 和 Kolczynski 1993;Mertens、Kautz 和 Van Reeth 2007; 空间域方法直接从源图像的像素值计算融合图像的像素值,并且通常使用三种类型的技术来融合空间域中的图像,即基于像素的方法(Liu和Wang 2015;Lee、Park和 Cho 2018)、基于补丁的方法(Ma 和 Wang 2015;Ma 等人 2017b)和基于优化的方法(Ma 等人 2017a)。

在基于变换域的融合算法中,首先将源图像变换到特定的变换域(例如小波域)以获得不同的频率分量,然后使用适当的融合规则来融合不同的频率分量。 最后,通过对融合频率分量进行逆变换获得融合图像。 常用的变换方法包括金字塔变换(Burt and Kolczynski 1993)、拉普拉斯金字塔(Mertens、Kautz和Van Reeth 2007)、小波变换(Li、Manjunath和Mitra 1995)和边缘保留平滑(Kou et al. 2017) )等。

尽管传统方法取得了有希望的融合结果,但弱的手工表示和低泛化性阻碍了进一步的改进。

2.2 Deep-Learning Based MEF Algorithms

在基于深度学习的算法中,两张曝光不同的源图像直接输入到融合网络中,网络输出融合图像。 融合网络可以使用地面实况融合图像(Zhang et al. 2020b;Wang et al. 2018;Li and Zhuang 2018)或基于相似性度量的损失函数(Xu et al. 2020a;Ram Prabhakar、Sai Srikar 和 Venkatesh Babu 2017;Xu 等人 2020b;Ma 等人 2019。

由于MEF领域缺乏真实的groundtruth融合图像,人们提出了几种合成groundtruth的方法。 例如,王等人。 (Wang et al. 2018)通过改变正常图像的像素强度生成地面实况数据(Deng et al. 2009),而其他研究人员利用其他 MEF 方法的融合结果作为地面实况(Yin 等人,2020;Chen 和 庄 2020;徐、马和张 2020)。 然而,这些地面实况图像并不真实,导致融合性能较差。

除了使用地面实况图像进行训练之外,另一个研究方向是使用基于相似度度量的损失函数来训练融合网络,以鼓励融合图像保留来自源图像不同方面的重要信息。 例如,普拉巴卡尔等人。 (Ram Prabhakar、Sai Srikar 和 Venkatesh Babu 2017)应用无参考图像质量度量 (MEF-SSIM) 作为损失函数。 张等人。 (Zhang et al. 2020a)设计了一个基于梯度和强度信息的损失函数来执行无监督训练。 徐等人。 (Xu et al. 2020a,b) 提出了 U2Fusion,其中训练融合网络以保留融合结果与源图像之间的自适应相似性。 尽管这些方法不需要地面实况图像,但仍然需要大量的多重曝光图像来进行训练。 尽管已经收集了多个多重曝光数据集(Ram Prabhakar、Sai Srikar和Venkatesh Babu 2017;Cai、Gu和Zhang 2018;Zeng等人2014),但它们的数量无法与大型自然图像数据集(例如ImageNet(Deng) 等人,2009 年)或 MS-COCO(林等人,2014 年)。 缺乏大量训练数据会导致过度拟合或繁琐的参数优化。

值得注意的是,研究人员已经在红外和可见光图像融合(Li and Wu 2018)以及多焦点图像融合任务(Ma et al. 2021)中使用了编码器解码器网络。 他们在自然图像数据集上训练编码器-解码器网络,但由于域差异,网络无法有效学习多重曝光图像的特征。 相比之下,我们根据多重曝光图像的特征设计了三个自监督重建任务,这样我们的网络不仅可以在大型自然图像数据集上进行训练,而且还能够学习多重曝光图像的特征。

3 Method

3.1 Framework Overview

如图 1 所示,我们的框架是基于编码器-解码器的架构。 我们通过大型自然图像数据集的图像重建来训练网络。 在融合阶段,我们应用经过训练的编码器从一对源图像中提取特征图,然后融合两个特征图并将其输入到解码器中以生成融合图像。

框架训练过程如图1(a)所示,其中我们使用网络执行自监督图像重建任务,即从被破坏的输入图像重建原始图像。 具体来说,给定原始图像 Iin ∈ RH×W,通过使用三种不同变换之一(基于伽玛的变换 TG(· )、基于傅立叶的变换 TF(·) 和全局区域改组 TS(·))。 被破坏的图像被输入到Encoder中,Encoder由特征提取模块TransBlock和特征增强模块EnhanceBlock组成。 TransBlock 使用 CNN 模块和 Transformer 模块进行特征提取。 被破坏的图像〜IN in直接输入到CNN模块,同时将它们分成块〜IN p,然后输入到Transformer-Module。 EnhanceBlock 聚合并增强从 CNN 模块和 Transformer 模块提取的特征图。 最后,Decoder利用Encoder提取的图像特征得到重建图像IN out(N = 1, 2, 3)。 我们利用多任务学习方法,同时执行三个自我监督的重建任务。 编码器的详细结构和三个重建任务分别在3.2和3.3节中介绍。

训练好的Encoder和Decoder随后用于图像融合,如图1(b)所示。 具体来说,首先将两幅源图像Ik(k = 1, 2)输入Encoder进行特征编码,然后使用Fusion Rule将提取的特征图F1和F2融合,得到融合后的特征图F。 最后,融合图像IF由解码器重建。 融合规则在3.4节中有详细描述。 这里,我们只介绍该框架的单通道灰度图像融合流程,彩色图像的融合在3.5节中详细阐述。

3.2 Transformer-Based Encoder-Decoder Framework

Encoder-Decoder Framework for Image Reconstruction

编码器-解码器网络如图1(a)所示。 以单个自监督重建任务为例,给定训练图像 Iin ∈ RH×W,我们首先随机生成 10 个图像子区域 xi ∈ RHi×Wi , (i = 1, 2, ... 10) 形成 将被变换的集合 χ = {x1, x2, . 。 。 ,x10},其中子区域Hi、Wi的大小都是从正整数集合[1,25]中均匀采样的随机值。 之后,我们使用为多曝光图像融合量身定制的图像变换对集合 χ 中的每个子区域 xi 进行变换(这三种不同的变换在 3.3 节中详细描述),以获得变换后的子区域集合 χ,然后将其使用 替换原来的子区域以获得变换后的图像〜Iin。 在图1(a)中,TG(·)、TF(·)和TS(·)分别表示基于伽玛变换、傅立叶变换和全局区域混洗的变换。

Encoder包含特征提取模块TransBlock和特征增强模块EnhanceBlock。 TransBlock的详细架构将在下面介绍。 特征增强模块EnhanceBlock对TransBlock提取的特征图进行聚合和增强,使得Encoder能够更好地融合全局和局部特征。 具体来说,我们将 TransBlock 中 CNN 模块和 Transformer 模块的两个特征图连接起来,并将它们输入到两个顺序连接的 ConvBlock 层中以实现特征增强。 如图 1 (c) 所示,每个 ConvBlock 由两个内核大小为 3×3 的卷积层、填充为 1 和一个 ReLU 激活层组成。 Decoder 包含两个顺序连接的 ConvBlock 层和最终的 1×1 卷积来重建原始图像。

TransBlock: A Powerful Feature Extractor

受 TransUnet (Chen et al. 2021) 和 ViT (Dosovitskiy et al. 2020) 的启发,我们提出了一种特征提取模块 TransBlock,它结合了 CNN 和 Transformer 架构来对图像中的局部和全局依赖性进行建模。 TransBlock的架构如图1(a)所示。 具体来说,CNN-Module 由三个顺序连接的 ConvBlock 层组成,CNN-Module 的输入是被破坏的图像。 同时,被破坏的图像I~in ε RH×W被划分为总共M个大小为H P ×W P 的块。 这些补丁用于构造序列 xseq ∈ RM×P2 ,其中 xseq =  xk p  , (k = 1, 2, . . .M),M = HW/P2 且 P 是补丁的大小。 该序列被输入到 Transformer-Module 中,该模块以块嵌入线性投影 E 开始,并获得编码的序列特征 z0 ∈ RM×D。 然后,z0经过L个Transformer层,每层的输出表示为zl(l = 1...L)。 图 1 (c) 说明了一个 Transformer 层的架构,该层由多头注意机制 (MSA) 块和多层感知器 (MLP) 块组成,其中在每个块和残差之前应用层归一化 (LN) 连接在每个块之后应用。 MLP 块由两个具有 GELU 激活函数的线性层组成。

Loss Function

我们的架构采用多任务学习方法,使用以下损失函数同时执行三个自监督重建任务。

其中 Loss 表示整体损失函数。 LossTaskG 、 LossTaskF 和 LossTaskS 是三个自监督重建任务的损失函数。

在每个重建任务中,我们鼓励网络不仅学习像素级图像重建,还捕获图像中的结构和梯度信息。 因此,每个重建任务的损失包含三部分,定义如下:

其中Lmse是均方误差(MSE)损失函数,Lssim是结构相似性(SSIM)损失函数,LTV是总变差损失函数。 λ1 和 λ2 是两个超参数,根据经验设置为 20。

MSE损失用于确保像素级重建,定义如下:

其中 Iout 是输出,即网络重建的融合图像,Iin 表示输入,即原始图像。

SSIM损失帮助模型更好地从图像中学习结构信息,定义为:

VIFNet(Hou et al. 2020)中引入的总变异损失 LTV 用于更好地保留源图像中的梯度并进一步消除噪声。 它的定义如下:

其中R(p,q)表示原始图像和重建图像之间的差异,|| · ||2 表示 l2 范数,p、q 分别表示图像像素的横坐标和纵坐标。

3.3 Three Specific Self-Supervised Image Reconstruction Tasks

在本节中,我们介绍三种变换,它们会破坏原始图像并生成图像重建编码器-解码器网络的输入。 补充材料第 1 节中提供了一个示例,显示了变换前后的图像和相应的子区域。

(1) Learning Scene Content and Luminance Information using Gamma-based Transformation.

一般来说,曝光过度的图像在黑暗区域包含足够的内容和结构信息,而曝光不足的图像在明亮区域包含足够的颜色和结构信息。 在融合图像中,需要保持均匀的亮度,同时保留所有区域的丰富信息(Xu et al. 2020a;Ram Prabhakar、Sai Srikar 和 Venkatesh Babu 2017)。 我们采用伽玛变换来改变原始图像的几个子区域的亮度,并训练网络来重建原始图像。 在此过程中,我们的网络从不同亮度级别的图像中学习内容和结构信息。

伽马变换定义为:

其中 ψ 和 ψ 分别是变换后的像素值和原始像素值。 对于所选子区域 xi 中的每个像素,我们使用随机伽玛变换 Γ 来改变亮度,其中伽玛是从区间 [0, 3] 均匀采样的随机值。

(2) Learning Texture and Detail Information using Fourier-based Transformation.

我们引入了基于傅里叶变换的自监督任务,使网络能够从频域学习纹理和细节信息。

在图像的离散傅立叶变换(DFT)中,幅度谱决定图像的强度,而相位谱主要决定图像的高级语义,包含有关图像内容和物体位置的信息。 (有关进一步的描述和实验,请参阅补充材料第 1.2 节)。

由于曝光时间不足而导致曝光不足图像太暗,而由于曝光时间过长导致曝光过度图像太亮,这两种情况都会导致图像强度分布不合适。 因此,鼓励网络从图像中学习适当的强度分布至关重要。

尽管曝光不足和曝光过度的图像的强度分布都很差,但图像中物体的形状和内容仍然很好地包含在相位谱中。 因此,在这种情况下建立一个可以捕获形状和内容信息的网络是有益的。

为此,对于选定的图像子区域,我们首先进行傅里叶变换以获得幅度谱和相位谱。 然后,我们破坏频域中的子区域。 具体来说,采用高斯模糊(σ = 0.5)来改变幅度谱,并对相位谱中的所有相位值进行np次随机交换,其中np是正整数集合[1, 5]中的随机数。

(3) Learning Structure and Semantic Information using Global Region Shuffling

我们引入全局区域改组变换(Kang et al. 2017)来破坏原始图像,从而使网络能够通过图像重建来学习结构和语义信息。 具体来说,对于在原始图像 Iin 中选择的图像子区域 χ 集合中的每个图像子区域 xi,我们随机选择另一个与 xi 大小相同的图像子区域 x i 。 之后,将它们交换并重复该过程10次,以获得被破坏的图像。

3.4 Fusion Rule

由于我们的网络已经具有很强的特征提取能力,我们简单地对两个源图像F1和F2的特征图进行平均,得到融合后的特征图F,然后将其转发到解码器。

3.5 Managing RGB Input

我们采用了之前基于深度学习的研究中常用的策略来融合 RGB 多重曝光图像(Zhang et al. 2020a)。 彩色图像的 RGB 通道首先转换为 YCbCr 颜色空间。 然后,使用我们的网络融合Y(亮度)通道,并使用传统的加权平均方法融合Cb和Cr(色度)通道中的信息,定义为:

其中 C1 和 C2 表示多重曝光图像的 Cb(或 Cr)通道值。 Cf 表示它们的融合通道结果,其中 τ 设置为 128。最后,融合的 Y 通道、Cb 和 Cr 通道被转换回 RGB 空间。

4 Experiments and Results

4.1 Datasets

我们使用大型自然数据集 MS-COCO (Lin et al. 2014) 来训练编码器-解码器网络。 MS-COCO包含超过70,000张各种场景的自然图像。 为了方便起见,所有图像都调整为 256 × 256 并转换为灰度图像。 值得一提的是,尽管已经提出了许多有竞争力的 MEF 算法,但它们并没有在统一的 MEF 基准上进行评估。 我们使用最新发布的多重曝光图像融合基准数据集(Zhang 2021)作为测试数据集。 该基准数据集由 100 对具有各种场景和多个对象的多重曝光图像组成。

4.2 Implementation Details

我们的网络在 NVIDIA GTX 3090 GPU 上进行训练,批量大小为 64 和 70 epoch。 我们使用 ADAM 优化器和余弦退火学习率调整策略,学习率为 1e-4,权重衰减为 0.0005。 对于 256 × 256 的训练图像,我们随机生成 10 个随机大小的子区域,以形成要变换的集合 χ。 在TransBlock中,我们将变换后的输入图像划分为大小为16×16的块并构造序列xseq。

4.3 Evaluation Metrics

我们使用主观和客观评估来严格评估我们的方法(Zhang 2021)。 主观评价是观察者对融合图像质量在清晰度、细节和对比度等方面的主观评价。 在客观评价中,为了与其他融合方法进行公正、全面的比较,我们从四个角度选取了12个客观评价指标。 其中包括基于信息论的指标、QMI、QTE、QNICE、PSNR、FMI; 基于图像特征的度量,QA/BF、QP、STD、QG; 基于图像结构相似性的度量,SSIM,CC; 和人类感知启发的指标,VIF。 有关指标的详细信息可以在补充材料第 3 节中找到。所有客观指标均按 100 个融合图像的平均值计算,值越大表示所有指标的性能越好。

我们将我们的方法与 11 种竞争性传统方法 (Li, Manjunath, and Mitra 1995; Liu and Wang 2015; Lee, Park, and Cho 2018; Ma and Wang 2015; Ma et al. 2017b,a) 和深度学习方法 (Zhang 等人 2020b;Ram Prabhakar、Sai Srikar 和 Venkatesh Babu 等人 2020b;Ma 等人 2019。 比较方法如下: 传统方法包括:DWT (Li, Manjunath, and Mitra 1995)、DSIFT-EF (Liu and Wang 2015)、MEFAW (Lee, Park, and Cho 2018)、PWA (Ma and Wang 2015) )、SPD-MEF(Ma et al. 2017b)和 MEFOpt(Ma et al. 2017a),基于深度学习的方法包括 Deepfuse(Ram Prabhakar、Sai Srikar 和 Venkatesh Babu 2017)、MEFNet(Ma et al. 2017a)。 2019)、U2Fusion(Xu 等人,2020a)、PMGI(Zhang 等人,2020a)和 IFCNN(Zhang 等人,2020b)。

4.4 Subjective Evaluation

图 2 显示了我们的方法和竞争对手在室内和室外场景中的融合结果。 更多融合结果显示在补充材料第 4 节中。

当融合图 2 (a1) 和 (b1) 中的第一对源图像时,DSIFT-EF、MEFAW、MEFOpt、SPD-MEF 和 MEFNet 导致亮度维持效果令人失望,并且融合图像显得较暗。 PWA 引入了伪影,并且颜色不真实。 虽然DWT、Deepfuse、PMGI、IFCNN和U2Fusion保持了中等亮度,但它们的融合结果对比度较低,无法描绘图像的细节。 相比之下,我们的方法保持了最佳的亮度和对比度,同时显示出出色的细节和更好的视觉感知。

当融合图 2 (a2) 和 (b2) 中的第二对源图像时,大多数方法无法保持适当的亮度。 MEFNet 和 PMGI 保持相对较好的亮度,但会引入伪影和模糊。 显然,我们的方法保持了最佳的亮度和对比度,同时保留了更详细的信息。

4.5 Objective Evaluation

表 1 列出了基准数据集上所有比较方法的客观评价。 我们的方法在 12 个指标中的 9 个指标上实现了最佳性能,而对于其他三个指标,我们的方法的结果与最佳结果之间的差距很小。

5 Ablation Study

5.1 Ablation Study for TransBlock

为了验证 TransBlock 的有效性,我们使用 20% 的训练数据进行了消融研究,消融研究的结果如表 2 所示。无论是否使用所提出的自监督重建任务,添加 TransBlock 总是会提高 融合性能。

为了进一步解释为什么 TransBlock 是有效的,我们使用传统的 CNN 架构和包含 TransBlock 的模型来可视化图像重建的效果。 可以看出后者重构了更好的细节。 更多信息请参阅补充材料第 2 节。

5.2 Ablation Study for Three Specific Self-Supervised Image Reconstruction Tasks

在这项消融研究中,我们证明了每个自监督重建任务的有效性以及以多任务方式同时执行这些任务的优越性。 本研究使用20%的训练数据进行,实验结果如表3所示。结果表明,单独的每个自监督重建任务都可以提高融合性能,并且通过进行整体最佳性能 通过多任务学习同时完成三个任务。

6 Conclusion

在本文中,我们提出了 TransMEF,一种通过自监督多任务学习的基于 Transformer 的多重曝光图像融合框架。 TransMEF 基于编码器解码器结构,因此可以在大型自然图像数据集上进行训练。 TransMEF编码器集成了CNN模块和transformer模块,使网络可以同时关注局部和全局信息。 此外,我们根据多重曝光图像的特点设计了三个自监督重建任务,并利用多任务学习同时进行这些任务,以便网络在图像重建过程中学习这些特征。 大量的实验表明,与现有的竞争方法相比,我们的新方法在主观和客观评估方面都取得了最先进的性能。 所提出的 TransBlock 和自监督重建任务有可能应用于其他图像融合任务和图像处理的其他领域。

code

Network_TransMEF.py

# -*- coding: utf-8 -*-
# Citation:
# @article{qu2021transmef,
#   title={TransMEF: A Transformer-Based Multi-Exposure Image Fusion Framework using Self-Supervised Multi-Task Learning},
#   author={Qu, Linhao and Liu, Shaolei and Wang, Manning and Song, Zhijian},
#   journal={arXiv preprint arXiv:2112.01030},
#   year={2021}
# }
import torch
import torch.nn as nn
from torch import einsum
from einops import rearrange
from einops.layers.torch import Rearrange


def save_grad(grads, name):
    def hook(grad):
        grads[name] = grad

    return hook


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class OutConv(nn.Module):
    """1*1 conv before the output"""

    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)


class Encoder(nn.Module):
    """features extraction"""

    def __init__(self):
        super(Encoder, self).__init__()
        self.inc = DoubleConv(1, 16)
        self.layer1 = DoubleConv(16, 32)
        self.layer2 = DoubleConv(32, 48)

    def forward(self, x, grads=None, name=None):
        x = self.inc(x)
        x = self.layer1(x)
        x = self.layer2(x)

        if grads is not None:
            x.register_hook(save_grad(grads, name + "_x"))
        return x


class Encoder_Trans(nn.Module):
    """features extraction"""

    def __init__(self):
        super(Encoder_Trans, self).__init__()
        self.inc = DoubleConv(1, 16)
        self.layer1 = DoubleConv(17, 32)
        self.layer2 = DoubleConv(32, 48)
        self.transformer = ViT(image_size=256, patch_size=16, dim=256, depth=12, heads=16, mlp_dim=1024, dropout=0.1,
                               emb_dropout=0.1)

    def forward(self, x, grads=None, name=None):
        x_e = self.inc(x)
        x_t = self.transformer(x)
        x = torch.cat((x_e, x_t), dim=1)
        x = self.layer1(x)
        x = self.layer2(x)

        if grads is not None:
            x.register_hook(save_grad(grads, name + "_x"))
        return x


class Decoder(nn.Module):
    """reconstruction"""

    def __init__(self):
        super(Decoder, self).__init__()
        self.layer1 = DoubleConv(48, 32)
        self.layer2 = DoubleConv(32, 16)
        self.outc = OutConv(16, 1)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        output = self.outc(x)
        return output


class Decoder_Trans(nn.Module):
    """reconstruction"""

    def __init__(self):
        super(Decoder_Trans, self).__init__()
        self.layer3 = DoubleConv(49, 48)
        self.layer4 = DoubleConv(48, 48)
        self.layer1 = DoubleConv(48, 32)
        self.layer2 = DoubleConv(32, 16)
        self.outc = OutConv(16, 1)

    def forward(self, x):
        x = self.layer4(self.layer3(x))
        x = self.layer1(x)
        x = self.layer2(x)
        output = self.outc(x)
        return output


class SimNet(nn.Module):
    """easy network for self-reconstruction task"""

    def __init__(self):
        super(SimNet, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, dim, depth, heads, mlp_dim, channels=1, dim_head=64,
                 dropout=0., emb_dropout=0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'

        patch_dim = channels * patch_size ** 2

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
            nn.Linear(patch_dim, dim)
        )
        self.dim = dim
        self.patch_size = patch_size
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.convd1 = nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, padding=1),
            nn.ReLU(inplace=True))

    def forward(self, img):
        x = self.to_patch_embedding(img)  # [B,256,256]
        b, n, _ = x.shape

        x = self.transformer(x)
        x = Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', p1=self.patch_size, h=16, c=1)(x)  # [B,1,256,256]

        return x


class TransNet(nn.Module):
    """U-based network for self-reconstruction task"""

    def __init__(self):
        super(TransNet, self).__init__()

        self.encoder = Encoder_Trans()
        self.decoder = Decoder()

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

dataloader_transMEF.py

# -*- coding: utf-8 -*-
# Citation:
# @article{qu2021transmef,
#   title={TransMEF: A Transformer-Based Multi-Exposure Image Fusion Framework using Self-Supervised Multi-Task Learning},
#   author={Qu, Linhao and Liu, Shaolei and Wang, Manning and Song, Zhijian},
#   journal={arXiv preprint arXiv:2112.01030},
#   year={2021}
# }
from __future__ import print_function
import torch.utils.data as Data
import torchvision.transforms as transforms
import numpy as np
from glob import glob
import os
import copy
from PIL import Image
import random
from imgaug import augmenters as iaa

sometimes = lambda aug: iaa.Sometimes(0.8, aug)
np.random.seed(2)


def local_pixel_shuffling(x):
    image_temp = copy.deepcopy(x)
    orig_image = copy.deepcopy(x)
    img_rows, img_cols = x.shape
    num_block = 10
    for _ in range(num_block):
        block_noise_size_x = random.randint(1, img_rows // 10)
        block_noise_size_y = random.randint(1, img_cols // 10)
        noise_x = random.randint(0, img_rows - block_noise_size_x)
        noise_y = random.randint(0, img_cols - block_noise_size_y)
        window = orig_image[noise_x:noise_x + block_noise_size_x,
                 noise_y:noise_y + block_noise_size_y]
        window = window.flatten()
        np.random.shuffle(window)
        window = window.reshape((block_noise_size_x,
                                 block_noise_size_y))
        image_temp[noise_x:noise_x + block_noise_size_x,
        noise_y:noise_y + block_noise_size_y] = window
    local_shuffling_x = image_temp

    return local_shuffling_x


def global_patch_shuffling(x):
    image_temp = copy.deepcopy(x)
    orig_image = copy.deepcopy(x)

    img_rows, img_cols = x.shape
    num_block = 10
    for _ in range(num_block):
        block_noise_size_x = random.randint(1, img_rows // 10)
        block_noise_size_y = random.randint(1, img_cols // 10)

        noise_x1 = random.randint(0, img_rows - block_noise_size_x)
        noise_y1 = random.randint(0, img_cols - block_noise_size_y)

        noise_x2 = random.randint(0, img_rows - block_noise_size_x)
        noise_y2 = random.randint(0, img_cols - block_noise_size_y)

        window1 = orig_image[noise_x1:noise_x1 + block_noise_size_x,
                  noise_y1:noise_y1 + block_noise_size_y]
        window2 = orig_image[noise_x2:noise_x2 + block_noise_size_x,
                  noise_y2:noise_y2 + block_noise_size_y]

        window_tmp = window1
        window1 = window2
        window2 = window_tmp

        image_temp[noise_x1:noise_x1 + block_noise_size_x,
        noise_y1:noise_y1 + block_noise_size_y] = window1
        image_temp[noise_x2:noise_x2 + block_noise_size_x,
        noise_y2:noise_y2 + block_noise_size_y] = window2

    local_shuffling_x = image_temp

    return local_shuffling_x


def brightness_aug(x, gamma):
    aug_brightness = iaa.Sequential(sometimes(iaa.GammaContrast(gamma=gamma)))
    aug_image = aug_brightness(images=x)
    return aug_image


def bright_transform(x):
    image_temp = copy.deepcopy(x)
    orig_image = copy.deepcopy(x)
    img_rows, img_cols = x.shape
    num_block = 10
    for _ in range(num_block):
        block_noise_size_x = random.randint(1, img_rows // 10)
        block_noise_size_y = random.randint(1, img_cols // 10)
        noise_x = random.randint(0, img_rows - block_noise_size_x)
        noise_y = random.randint(0, img_cols - block_noise_size_y)
        window = orig_image[noise_x:noise_x + block_noise_size_x,
                 noise_y:noise_y + block_noise_size_y]
        window = brightness_aug(window, 3 * np.random.random_sample())

        image_temp[noise_x:noise_x + block_noise_size_x,
        noise_y:noise_y + block_noise_size_y] = window
    bright_transform_x = image_temp

    return bright_transform_x


def fourier_broken(x, nb_rows, nb_cols):
    aug_a = iaa.GaussianBlur(sigma=0.5)
    aug_p = iaa.Jigsaw(nb_rows=nb_rows, nb_cols=nb_cols, max_steps=(1, 5))
    fre = np.fft.fft2(x)
    fre_a = np.abs(fre)
    fre_p = np.angle(fre)
    fre_a_normalize = fre_a / (np.max(fre_a) + 0.0001)
    fre_p_normalize = fre_p
    fre_a_trans = aug_a(image=fre_a_normalize)
    fre_p_trans = aug_p(image=fre_p_normalize)
    fre_a_trans = fre_a_trans * (np.max(fre_a) + 0.0001)
    fre_p_trans = fre_p_trans
    fre_recon = fre_a_trans * np.e ** (1j * (fre_p_trans))
    img_recon = np.abs(np.fft.ifft2(fre_recon))
    return img_recon


def fourier_transform(x):
    image_temp = copy.deepcopy(x)
    orig_image = copy.deepcopy(x)
    img_rows, img_cols = x.shape
    num_block = 10
    for _ in range(num_block):
        block_noise_size_x = random.randint(1, img_rows // 10)
        block_noise_size_y = random.randint(1, img_cols // 10)
        noise_x = random.randint(0, img_rows - block_noise_size_x)
        noise_y = random.randint(0, img_cols - block_noise_size_y)
        window = orig_image[noise_x:noise_x + block_noise_size_x,
                 noise_y:noise_y + block_noise_size_y]
        window = fourier_broken(window, block_noise_size_x, block_noise_size_y)
        image_temp[noise_x:noise_x + block_noise_size_x,
        noise_y:noise_y + block_noise_size_y] = window
    bright_transform_x = image_temp

    return bright_transform_x


class Fusionset(Data.Dataset):
    def __init__(self, io, args, root, transform=None, gray=True, partition='train', ssl_transformations=None):
        self.files = glob(os.path.join(root, '*.*'))
        self.gray = gray
        self._tensor = transforms.ToTensor()
        self.transform = transform
        self.ssl_transformations = ssl_transformations
        self.args = args

        if args.miniset == True:
            self.files = random.sample(self.files, int(args.minirate * len(self.files)))
        self.num_examples = len(self.files)

        if self.ssl_transformations == True:
            print('used ssl_transformations')
        else:
            print('not used ssl_transformations')

        if partition == 'train':
            self.train_ind = np.asarray([i for i in range(self.num_examples) if i % 10 < 8]).astype(np.int)
            np.random.shuffle(self.train_ind)
            self.val_ind = np.asarray([i for i in range(self.num_examples) if i % 10 >= 8]).astype(np.int)
            np.random.shuffle(self.val_ind)
        io.cprint("number of " + partition + " examples in dataset" + ": " + str(len(self.files)))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, index):
        img = Image.open(self.files[index])
        if self.transform is not None:
            img = self.transform(img)
        if self.gray:
            img = img.convert('L')
        img = np.array(img)

        if self.ssl_transformations == True:
            img_bright_orig = img.copy()
            img_bright_trans = bright_transform(img_bright_orig)
            img_bright_trans = self._tensor(img_bright_trans)

            img_fourier_orig = img.copy()
            img_fourier_trans = fourier_transform(img_fourier_orig)
            img_fourier_trans = self._tensor(img_fourier_trans)

            img_shuffling_orig = img.copy()
            img_shuffling_trans = global_patch_shuffling(img_shuffling_orig)
            img_shuffling_trans = self._tensor(img_shuffling_trans)
            img = self._tensor(img)
        else:
            img = self._tensor(img)
            img_bright_trans = img
            img_fourier_trans = img
            img_shuffling_trans = img

        return img, img_bright_trans, img_fourier_trans, img_shuffling_trans

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值