Retinexformer: One-stage Retinex-basedTransformer for Low-light Image Enhancement

Abstract

在增强低光图像时,许多深度学习算法都是基于 Retinex 理论。然而,Retinex 模型没有考虑隐藏在黑暗中或由点亮过程引入的损坏。此外,这些方法通常需要繁琐的多阶段训练流程,并依赖于卷积神经网络,在捕获远程依赖性方面表现出局限性。在本文中,我们制定了一个简单但有原则的基于 Retinex 的单阶段框架(ORF)。 ORF 首先估计照明信息以照亮低光图像,然后恢复损坏以生成增强的图像。我们设计了一个照明引导变压器(IGT),它利用照明表示来指导不同照明条件下区域的非局部交互的建模。通过将 IGT 插入 ORF,我们获得了我们的算法 Retinexformer。全面的定量和定性实验表明,我们的 Retinexformer 在 13 个基准测试中显着优于最先进的方法。用户对弱光物体检测的研究和应用也揭示了我们的方法的潜在实用价值。代码可在 https://github .com/caiyuanhao1998/Retinexformer上获取。

1. Introduction

低光图像增强是计算机视觉中一项重要但具有挑战性的任务。它的目的是改善低光图像的可见度差和对比度低的问题,并恢复隐藏在黑暗中或由点亮过程引入的损坏(例如噪声、伪影、颜色失真等)。这些问题不仅挑战人类视觉感知,还挑战夜间物体检测等其他视觉任务。

因此,人们提出了大量的低光图像增强算法。然而,这些现有的算法都有其自身的缺点。直方图均衡和伽马校正等简单方法往往会产生不需要的伪影,因为它们几乎不考虑照明因素。传统的认知方法依赖于 Retinex 理论 [27],该理论假设彩色图像可以分解为两个组成部分,即反射率和照明度。与普通方法不同,传统方法侧重于照明估计,但通常会引入严重的噪声或局部颜色失真,因为这些方法假设图像无噪声和颜色失真。这与真实的曝光不足场景不一致。

随着深度学习的发展,卷积神经网络(CNN)已应用于弱光图像增强。这些基于CNN的方法主要分为两类。第一类直接使用 CNN 来学习从低光图像到正常光图像的强力映射函数,从而忽略了人类的颜色感知。这种方法缺乏可解释性和理论上证明的特性。第二类是受到Retinex理论的启发。这些方法 [54,65,66] 通常会受到多阶段训练流程的影响。他们分别采用不同的 CNN 来分解彩色图像、对反射率进行降噪并调整照明。这些 CNN 首先独立训练,然后连接在一起进行端到端微调。训练过程繁琐且耗时。

此外,这些基于 CNN 的方法在捕获长程依赖性和非局部自相似性方面表现出局限性,而这对于图像恢复至关重要。最近兴起的深度学习模型 Transformer 或许为解决基于 CNN 的方法的这一缺陷提供了可能。然而,直接应用原始视觉 Transformer 进行低光图像增强可能会遇到问题。计算复杂度与输入空间大小成二次方。这种计算成本可能是难以承受的。由于这一限制,一些 CNN-Transformer 混合算法(如 SNR-Net [57])仅在 U 形 CNN 的最低空间分辨率下采用单个全局 Transformer 层。因此,Transformer 低光图像增强的潜力仍未得到充分开发。

为了解决上述问题,我们提出了一种新的方法,Retinexformer,用于低光图像增强。首先,我们制定了一个简单但有原则的基于 Retinex 的单阶段框架(ORF)。我们通过向反射率和照明引入扰动项来对原始 Retinex 模型进行修改,以对腐败进行建模。我们的 ORF 估计照明信息并用它来照亮低光图像。然后 ORF 采用损坏恢复器来抑制噪声、伪影、曝光不足/过度以及颜色失真。与之前基于 Retinex 的深度学习框架存在繁琐的多阶段训练流程不同,我们的 ORF 以单阶段方式进行端到端训练。其次,我们提出了一种照明引导变压器(IGT)来对远程依赖性进行建模。 IGT 的关键组件是照明引导多头自注意力(IG-MSA)。 IG-MSA 利用照明表示来指导自注意力的计算并增强不同曝光水平区域之间的交互。最后,我们将 IGT 作为损坏恢复器插入 ORF,以导出我们的方法 Retinexformer。如图 1 所示,我们的 Retinexformer 在各种数据集上大幅超越了最先进(SOTA)的基于 Retinex 的深度学习方法。特别是在 SID [9]、SDSD [48]-室内和 LOL-v2 [59]-合成上,改进超过 6 dB。

我们的贡献可总结如下:

• 我们提出了第一个基于Transformer 的算法Retinexformer,用于低光图像增强。

• 我们制定了一种基于 Retinex 的单阶段低光增强框架 ORF,它具有简单的单阶段训练过程并能很好地对损坏进行建模。

• 我们设计了一种新的自注意力机制IG-MSA,它利用照明信息作为关键线索来指导远程依赖性建模。

• 定量和定性实验表明,我们的 Retinexformer 在 13 个数据集上优于 SOTA 方法。用户研究和弱光检测的结果也表明了我们的方法的实用价值。

2. Related Work

2.1. Low-light Image Enhancement

Plain Methods.

直方图均衡[1,8,12,40,41]和伽马校正(GC)[19,42,53]等简单方法直接放大曝光不足图像的低可见度和对比度。然而,这些方法几乎没有考虑光照因素,使得增强后的图像在感知上与真实的正常光场景不一致。

Traditional Cognition Methods.

与普通算法不同,传统方法[15,23,24,29,50]考虑了光照因素。他们依靠 Retinex 理论,将低光图像的反射率分量视为增强结果的合理解决方案。例如,郭等人[18]建议通过在其上施加先验结构来细化初始估计的照明图。然而,这些方法天真地假设低光图像没有损坏,导致增强过程中出现严重的噪声和颜色失真。另外,这些方法依赖于手工设计的先验,通常需要仔细的参数调整,并且泛化能力较差。

Deep Learning Methods.

随着深度学习的快速进步,CNN [16,17,22,33,35,38,45,49,61,66,68]已广泛应用于弱光图像增强。例如,魏等人[54]和后续工作[65, 66]将Retinex分解与深度学习结合起来。然而,这些方法通常面临繁琐的多阶段训练流程。多个 CNN 分别用于学习或调整 Retinex 模型的不同组件。王等人 [49]提出了一种基于 Retinex 的单级 CNN,称为 DeepUPE,来直接预测光照图。尽管如此,DeepUPE 并没有考虑损坏因素,导致在点亮曝光不足的照片时会出现放大的噪点和颜色失真。此外,这些基于 CNN 的方法在捕获不同区域的远程依赖性方面也表现出局限性。

2.2. Vision Transformer

[46] 中提出了用于机器翻译的自然语言处理模型 Transformer。近年来,Transformer 及其变体已应用于许多计算机视觉任务,并在高级视觉方面取得了令人印象深刻的成果(例如图像分类 [2,4,14]、语义分割 [7,55,67]、目标检测) [3,13,62]等)和低级视觉(例如图像恢复[6,11,60],图像合成[20,21,64]等)。例如,徐等人[57]提出了一种 SNR 感知的 CNN-Transformer 混合网络 SNR-Net,用于低光图像增强。然而,由于普通全局 Transformer 的计算成本巨大,SNR-Net 仅在 U 形 CNN 的最低分辨率下采用单个全局 Transformer 层。 Transformer 在低光图像增强方面的潜力尚未得到充分开发。

3. Method

图 2 说明了我们方法的整体架构。如图 2 (a) 所示,我们的 Retinexformer 基于我们制定的单阶段基于 Retinex 的框架 (ORF)。 ORF 由照明估计器 (i) 和损坏恢复器 (ii) 组成。我们设计了一个照明引导变压器(IGT)来扮演腐败恢复者的角色。如图2(b)所示,IGT的基本单元是照明引导注意块(IGAB),它由两层归一化(LN)、照明引导多头自注意(IG-MSA)模块和前馈网络(FFN)。图2(c)显示了IG-MSA的细节。

3.1. One-stage Retinex-based Framework

根据Retinex理论。低光图像I \in R^{H \times W \times 3} 可以分解为反射率图像 R\in R^{H\times W \times 3} 和照度图 L \in R^{H\times W}

其中 ⊙ 表示逐元素乘法。该 Retinex 模型假设 I 是无损坏的,这与真实的曝光不足场景不一致。我们分析,腐败现象主要源于两个因素。首先,黑暗场景的高 ISO 和长时间曝光成像设置不可避免地会引入噪点和伪影。其次,点亮过程可能会放大噪声和伪影,还会导致曝光不足/曝光过度和颜色失真,如图2(a)的放大斑块i和ii所示。

为了对腐败进行建模,我们重新制定了方程。 (1) 分别为 R 和 L 引入扰动项,如

Discussion.

(i) 与之前基于 Retinex 的深度学习方法 [30, 49, 54, 65, 66] 不同,我们的 ORF 估计 ¯L 而不是照明图 L,因为如果 ORF 估计 L,则将获得点亮图像通过逐元素除法 (I./L)。计算机很容易受到此操作的影响。张量的值可以非常小(有时甚至等于 0)。除法很容易造成数据溢出问题。此外,计算机随机产生的小误差也会被该运算放大,导致估计不准确。因此,建模 ¯L 更加稳健。 (ii) 以前基于 Retinex 的深度学习方法主要集中于抑制反射图像上的噪声等损坏,即等式 12中的 ˆR。他们忽略了照明图上的估计误差,即等式中的 ^L。 (2),容易导致点亮过程中曝光不足/过度以及颜色失真。相比之下,我们的 ORF 考虑了所有这些损坏并使用 R 来恢复它们。

3.2. Illumination-Guided Transformer

以前的深度学习方法主要依赖于 CNN,在捕获长程依赖性方面表现出局限性。由于全局多头自注意力(MSA)的巨大计算复杂性,一些 CNN-Transformer 混合工作(例如 SNR-Net [57])仅在 U 形 CNN 的最低分辨率下使用全局 Transformer 层。 Transformer 的潜力尚未得到充分挖掘。为了填补这一空白,我们设计了一个照明引导变压器(IGT)来扮演等式(5)中的损坏恢复器 R 的角色。

Network Structure.

IG-MSA

Complexity Analysis.

4. Experiment

5. Conclusion

在本文中,我们提出了一种基于 Transformer 的新颖方法 Retinexformer,用于低光图像增强。我们从Retinex理论开始。通过分析隐藏在曝光不足的场景中以及由点亮过程引起的损坏,我们将扰动项引入到原始的Retinex模型中,并制定了一个新的基于Retinex的框架ORF。然后,我们设计了一个 IGT,利用 ORF 捕获的照明信息来指导不同照明条件下区域的远程依赖性和相互作用的建模。最后,我们的 Retinexformer 是通过将 IGT 插入 ORF 得到的。大量的定量和定性实验表明,我们的 Retinexformer 在 13 个数据集上的表现显着优于 SOTA 方法。用户研究和弱光检测的结果也证明了我们方法的实用价值。

代码解读

Illumination_Estimator

class Illumination_Estimator(nn.Module):
    def __init__(
            self, n_fea_middle, n_fea_in=4, n_fea_out=3):  #__init__部分是内部属性,而forward的输入才是外部输入
        super(Illumination_Estimator, self).__init__()

        self.conv1 = nn.Conv2d(n_fea_in, n_fea_middle, kernel_size=1, bias=True)

        self.depth_conv = nn.Conv2d(
            n_fea_middle, n_fea_middle, kernel_size=5, padding=2, bias=True, groups=n_fea_in)

        self.conv2 = nn.Conv2d(n_fea_middle, n_fea_out, kernel_size=1, bias=True)

    def forward(self, img):
        # img:        b,c=3,h,w
        # mean_c:     b,c=1,h,w
        
        # illu_fea:   b,c,h,w
        # illu_map:   b,c=3,h,w
        
        mean_c = img.mean(dim=1).unsqueeze(1)
        # stx()
        input = torch.cat([img,mean_c], dim=1)

        x_1 = self.conv1(input)
        illu_fea = self.depth_conv(x_1)
        illu_map = self.conv2(illu_fea)
        return illu_fea, illu_map

输入img三维,mean_c是一维,cat后为4维,n_fea_in=4,fea是经过两次conv得到,map是三次conv得到。

IG_MSA

class IG_MSA(nn.Module):
    def __init__(
            self,
            dim,
            dim_head=64,
            heads=8,
    ):
        super().__init__()
        self.num_heads = heads
        self.dim_head = dim_head
        self.to_q = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_k = nn.Linear(dim, dim_head * heads, bias=False)
        self.to_v = nn.Linear(dim, dim_head * heads, bias=False)
        self.rescale = nn.Parameter(torch.ones(heads, 1, 1))
        self.proj = nn.Linear(dim_head * heads, dim, bias=True)
        self.pos_emb = nn.Sequential(
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
            GELU(),
            nn.Conv2d(dim, dim, 3, 1, 1, bias=False, groups=dim),
        )
        self.dim = dim

    def forward(self, x_in, illu_fea_trans):
        """
        x_in: [b,h,w,c]         # input_feature
        illu_fea: [b,h,w,c]         # mask shift? 为什么是 b, h, w, c?
        return out: [b,h,w,c]
        """
        b, h, w, c = x_in.shape
        x = x_in.reshape(b, h * w, c)
        q_inp = self.to_q(x)
        k_inp = self.to_k(x)
        v_inp = self.to_v(x)
        illu_attn = illu_fea_trans # illu_fea: b,c,h,w -> b,h,w,c
        q, k, v, illu_attn = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads),
                                 (q_inp, k_inp, v_inp, illu_attn.flatten(1, 2)))
        v = v * illu_attn
        # q: b,heads,hw,c
        q = q.transpose(-2, -1)
        k = k.transpose(-2, -1)
        v = v.transpose(-2, -1)
        q = F.normalize(q, dim=-1, p=2)
        k = F.normalize(k, dim=-1, p=2)
        attn = (k @ q.transpose(-2, -1))   # A = K^T*Q
        attn = attn * self.rescale
        attn = attn.softmax(dim=-1)
        x = attn @ v   # b,heads,d,hw
        x = x.permute(0, 3, 1, 2)    # Transpose
        x = x.reshape(b, h * w, self.num_heads * self.dim_head)
        out_c = self.proj(x).view(b, h, w, c)
        out_p = self.pos_emb(v_inp.reshape(b, h, w, c).permute(
            0, 3, 1, 2)).permute(0, 2, 3, 1)
        out = out_c + out_p

        return out

q、k、v通过rearrange和transpose,尺寸为b heads hw c,attn是计算q、k的通道维度注意力,尺寸为b heads c c,因此降低了计算量。

FeedForward

class FeedForward(nn.Module):
    def __init__(self, dim, mult=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(dim, dim * mult, 1, 1, bias=False),
            GELU(),
            nn.Conv2d(dim * mult, dim * mult, 3, 1, 1,
                      bias=False, groups=dim * mult),
            GELU(),
            nn.Conv2d(dim * mult, dim, 1, 1, bias=False),
        )

    def forward(self, x):
        """
        x: [b,h,w,c]
        return out: [b,h,w,c]
        """
        out = self.net(x.permute(0, 3, 1, 2))
        return out.permute(0, 2, 3, 1)

卷积+GELU激活

Denoiser

class Denoiser(nn.Module):
    def __init__(self, in_dim=3, out_dim=3, dim=31, level=2, num_blocks=[2, 4, 4]):
        super(Denoiser, self).__init__()
        self.dim = dim
        self.level = level

        # Input projection
        self.embedding = nn.Conv2d(in_dim, self.dim, 3, 1, 1, bias=False)

        # Encoder
        self.encoder_layers = nn.ModuleList([])
        dim_level = dim
        for i in range(level):
            self.encoder_layers.append(nn.ModuleList([
                IGAB(
                    dim=dim_level, num_blocks=num_blocks[i], dim_head=dim, heads=dim_level // dim),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False),
                nn.Conv2d(dim_level, dim_level * 2, 4, 2, 1, bias=False)
            ]))
            dim_level *= 2

        # Bottleneck
        self.bottleneck = IGAB(
            dim=dim_level, dim_head=dim, heads=dim_level // dim, num_blocks=num_blocks[-1])

        # Decoder
        self.decoder_layers = nn.ModuleList([])
        for i in range(level):
            self.decoder_layers.append(nn.ModuleList([
                nn.ConvTranspose2d(dim_level, dim_level // 2, stride=2,
                                   kernel_size=2, padding=0, output_padding=0),
                nn.Conv2d(dim_level, dim_level // 2, 1, 1, bias=False),
                IGAB(
                    dim=dim_level // 2, num_blocks=num_blocks[level - 1 - i], dim_head=dim,
                    heads=(dim_level // 2) // dim),
            ]))
            dim_level //= 2

        # Output projection
        self.mapping = nn.Conv2d(self.dim, out_dim, 3, 1, 1, bias=False)

        # activation function
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, illu_fea):
        """
        x:          [b,c,h,w]         x是feature, 不是image
        illu_fea:   [b,c,h,w]
        return out: [b,c,h,w]
        """

        # Embedding
        fea = self.embedding(x)

        # Encoder
        fea_encoder = []
        illu_fea_list = []
        for (IGAB, FeaDownSample, IlluFeaDownsample) in self.encoder_layers:
            fea = IGAB(fea,illu_fea)  # bchw
            illu_fea_list.append(illu_fea)
            fea_encoder.append(fea)
            fea = FeaDownSample(fea)
            illu_fea = IlluFeaDownsample(illu_fea)

        # Bottleneck
        fea = self.bottleneck(fea,illu_fea)

        # Decoder
        for i, (FeaUpSample, Fution, LeWinBlcok) in enumerate(self.decoder_layers):
            fea = FeaUpSample(fea)
            fea = Fution(
                torch.cat([fea, fea_encoder[self.level - 1 - i]], dim=1))
            illu_fea = illu_fea_list[self.level-1-i]
            fea = LeWinBlcok(fea,illu_fea)

        # Mapping
        out = self.mapping(fea) + x

        return out

embedding是3x3卷积,然后是UNet型上下采样结构,Encoder由IGAB和conv下采样组成,

            fea = IGAB(fea,illu_fea)  # bchw

            fea = FeaDownSample(fea)

illu_fea每次进行一次下采样,illu_fea_list和fea_encoder保存illu_fea和fea,fea用于跳连合并,illu_fea用于Decoder过程IGAB使用。

Bottleneck由单个IGAB组成

Decoder:先进行上采样ConvTranspose2d,将上采样结果和fea_encoder层cat,再进行通道压缩,保持通道数不变,然后做IGAB处理。

mapping通道压回到3,最后加上输入x

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值