【论文精读】Zero-Reference Deep Curve Estimation for Low-Light Image Enhancement

出处

2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)

创新点

1.提出了第一个无参考网络(有待考究),避免了有参考网络中过度拟合的问题,提高了模型的泛化能力。
2.设计了一种特定于图像的曲线,能通过不断迭代,实现图像像素级增强的目的。
3.提出了适用于无参考网络的损失函数设计。

光照增强曲线设计

1.光照增强曲线设计应满足三个条件:

1)增强图像的每个像素值应保持在[0,-1]范围内;
2)曲线应单调,保持相邻像素之间的对比度;
3)曲线应该可微(适用于梯度反向传播)。

2.光照增强曲线

设计满足以上三个条件的曲线,公式如下:
L E ( I ( x ) ; α ) = I ( x ) + α I ( x ) ( 1 − I ( x ) ) (1) LE(I(\bold{x});\alpha)=I(\bold{x})+\alpha I(\bold{x})(1-I(\bold{x}))\tag{1} LE(I(x);α)=I(x)+αI(x)(1I(x))(1)
其中 I ( x ) I(x) I(x)表示输入图像, α \alpha α是可训练参数。随着 α \alpha α不同,可实现图像对比度范围调整。下图是不同 α \alpha α值的 L E LE LE曲线示意图:
在这里插入图片描述
显然,该曲线具有调整图像动态范围的能力。
尽管, L E LE LE曲线具有满足调整图像动态范围的能力,但是利用该曲线调整微光图像存在以下两个问题(仅个人观点):①对极度微光条件,其调整能力弱;②在进行微光增强的同时,会减弱强光部分,造成变换后的图像颜色失真。
该曲线实际只是压缩对比度的一种方法。

3.高阶光照增强曲线

为了解决2中提出的问题,本文在公式 ( 1 ) (1) (1)的基础上,设计了其高阶形式,同样满足1中提出的三个条件。具体公式如下:
L E n ( x ) = L E n − 1 + α n L E n − 1 ( x ) ( 1 − L E n − 1 ( x ) ) (2) LE_n(\bold{x})=LE_{n-1}+\alpha _{n}LE_{n-1}(\bold{x})(1-LE_{n-1}(\bold{x}))\tag{2} LEn(x)=LEn1+αnLEn1(x)(1LEn1(x))(2)
n n n取值为8时,可以处理大多数微光图像增强问题。下图给出了当 n n n取值为8时,不同 α \alpha α值的曲线示意图:
图2
如图所示,高阶曲线具有更强的微光增强能力,并且很好的保留高光部分。

4.像素级图像曲线

由公式 ( 2 ) (2) (2)可知 α \alpha α作用于所有像素,所以高阶曲线是一个全局调整过程。然而全局调整往往过度/不足增强局部区域。为了解决这个问题,将本文 α \alpha α表示为像素级参数,具体公式如下:
L E n ( x ) = L E n − 1 + A n ( x ) L E n − 1 ( x ) ( 1 − L E n − 1 ( x ) ) (3) LE_n(\bold{x})=LE_{n-1}+\mathcal{A} _{n}(\bold{x})LE_{n-1}(\bold{x})(1-LE_{n-1}(\bold{x}))\tag{3} LEn(x)=LEn1+An(x)LEn1(x)(1LEn1(x))(3)
其中 A ( x ) \mathcal{A} (\bold{x}) A(x)与输入图像尺寸一致。
下图给出了经过像素级图像曲线处理后,R,G,B通道结果图。
在这里插入图片描述

网络架构

该网络以UNet为主干网络,从中提取具有24通道的特征图像。将24通道按RGB划分8份,作为曲线参数 A 1 , A 2 , . . . A 8 \mathcal{A}_1,\mathcal{A}_2,...\mathcal{A}_8 A1,A2,...A8。根据公式 ( 3 ) (3) (3)对输入图像,迭代8次运算得到增强图像。

1.DCE-Net网络结构

在这里插入图片描述
如图所示,DCE-Net包含七个具有对称跳跃连接的卷积层。在前6个卷积层中,每个卷积层由32个大小为3×3、步长为1的卷积核组成,其后是RELU激活函数。最后一个卷积层由24个大小为3×3、步长为1的卷积核组成,其后是Tanh激活函数,该激活函数为8次迭代产生24个曲线参数映射,其中每次迭代需要3个通道(即,RGB通道)的3个曲线参数映射。
代码如下:

class enhance_net_nopool(nn.Module):

	def __init__(self):
		super(enhance_net_nopool, self).__init__()

		self.relu = nn.ReLU(inplace=True)

		number_f = 32
		self.e_conv1 = nn.Conv2d(3,number_f,3,1,1,bias=True) 
		self.e_conv2 = nn.Conv2d(number_f,number_f,3,1,1,bias=True) 
		self.e_conv3 = nn.Conv2d(number_f,number_f,3,1,1,bias=True) 
		self.e_conv4 = nn.Conv2d(number_f,number_f,3,1,1,bias=True) 
		self.e_conv5 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True) 
		self.e_conv6 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True) 
		self.e_conv7 = nn.Conv2d(number_f*2,24,3,1,1,bias=True) 

		self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
		self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)


		
	def forward(self, x):

		x1 = self.relu(self.e_conv1(x))
		# p1 = self.maxpool(x1)
		x2 = self.relu(self.e_conv2(x1))
		# p2 = self.maxpool(x2)
		x3 = self.relu(self.e_conv3(x2))
		# p3 = self.maxpool(x3)
		x4 = self.relu(self.e_conv4(x3))

		x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))#沿行进行拼接
		# x5 = self.upsample(x5)
		x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))

		x_r = F.tanh(self.e_conv7(torch.cat([x1,x6],1)))
		r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 3, dim=1)


		x = x + r1*(torch.pow(x,2)-x)
		x = x + r2*(torch.pow(x,2)-x)
		x = x + r3*(torch.pow(x,2)-x)
		enhance_image_1 = x + r4*(torch.pow(x,2)-x)		
		x = enhance_image_1 + r5*(torch.pow(enhance_image_1,2)-enhance_image_1)		
		x = x + r6*(torch.pow(x,2)-x)	
		x = x + r7*(torch.pow(x,2)-x)
		enhance_image = x + r8*(torch.pow(x,2)-x)
		r = torch.cat([r1,r2,r3,r4,r5,r6,r7,r8],1)
		return enhance_image_1,enhance_image,r

2.本文主要框架

在这里插入图片描述

损失函数设计

1.空间一致性损失 L s p a L_{spa} Lspa

L s p a = 1 K ∑ i = 1 K ∑ j ∈ Ω ( i ) ( ∣ ( Y i − Y j ) ∣ − ∣ ( I i − I j ) ∣ ) 2 (4) L_{spa}=\frac{1}{K}\sum_{i=1}^K\sum_{j\in\Omega(i)}(|(Y_i-Y_j)|-|(I_i-I_j)|)^2\tag{4} Lspa=K1i=1KjΩ(i)((YiYj)(IiIj))2(4)
其中 K K K是局部区域的个数, Ω ( I ) \Omega(I) Ω(I)是以区域 i i i为中心的四个相邻区域(上、下、左、右), Y Y Y I I I分别表示增强版本和输入图像中局部区域的平均强度值。该文经验性地将局部区域的大小设置为4×4。在其他区域大小的情况下,这种损失是稳定的。
代码如下:

class L_spa(nn.Module):

    def __init__(self):
        super(L_spa, self).__init__()
        # print(1)kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
        kernel_left = torch.FloatTensor( [[0,0,0],[-1,1,0],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
        kernel_right = torch.FloatTensor( [[0,0,0],[0,1,-1],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
        kernel_up = torch.FloatTensor( [[0,-1,0],[0,1, 0 ],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
        kernel_down = torch.FloatTensor( [[0,0,0],[0,1, 0],[0,-1,0]]).cuda().unsqueeze(0).unsqueeze(0)
        self.weight_left = nn.Parameter(data=kernel_left, requires_grad=False)
        self.weight_right = nn.Parameter(data=kernel_right, requires_grad=False)
        self.weight_up = nn.Parameter(data=kernel_up, requires_grad=False)
        self.weight_down = nn.Parameter(data=kernel_down, requires_grad=False)
        self.pool = nn.AvgPool2d(4)
    def forward(self, org , enhance ):
        b,c,h,w = org.shape

        org_mean = torch.mean(org,1,keepdim=True)
        enhance_mean = torch.mean(enhance,1,keepdim=True)

        org_pool =  self.pool(org_mean)			
        enhance_pool = self.pool(enhance_mean)	

        weight_diff =torch.max(torch.FloatTensor([1]).cuda() + 10000*torch.min(org_pool - torch.FloatTensor([0.3]).cuda(),torch.FloatTensor([0]).cuda()),torch.FloatTensor([0.5]).cuda())
        E_1 = torch.mul(torch.sign(enhance_pool - torch.FloatTensor([0.5]).cuda()) ,enhance_pool-org_pool)


        D_org_letf = F.conv2d(org_pool , self.weight_left, padding=1)
        D_org_right = F.conv2d(org_pool , self.weight_right, padding=1)
        D_org_up = F.conv2d(org_pool , self.weight_up, padding=1)
        D_org_down = F.conv2d(org_pool , self.weight_down, padding=1)

        D_enhance_letf = F.conv2d(enhance_pool , self.weight_left, padding=1)
        D_enhance_right = F.conv2d(enhance_pool , self.weight_right, padding=1)
        D_enhance_up = F.conv2d(enhance_pool , self.weight_up, padding=1)
        D_enhance_down = F.conv2d(enhance_pool , self.weight_down, padding=1)

        D_left = torch.pow(D_org_letf - D_enhance_letf,2)
        D_right = torch.pow(D_org_right - D_enhance_right,2)
        D_up = torch.pow(D_org_up - D_enhance_up,2)
        D_down = torch.pow(D_org_down - D_enhance_down,2)
        E = (D_left + D_right + D_up +D_down)
        # E = 25*(D_left + D_right + D_up +D_down)

        return E

2.曝光控制损失 L e x p L_{exp} Lexp

L e x p = 1 K ∑ k = 1 M ∣ Y k − E ∣ (5) L_{exp}=\frac{1}{K}\sum_{k=1}^M|Y_k-E|\tag{5} Lexp=K1k=1MYkE(5)
曝光控制损失衡量局部区域的平均强度值与曝光良好的级别E之间的距离。其中E设置为RGB颜色空间中的灰度级别。在实验中,将E设置为0.6,本文作者指出没有发现将E设置在[0.4,0.7]范围内会有太大的性能差异。
代码如下:

class L_exp(nn.Module):

    def __init__(self,patch_size,mean_val):
        super(L_exp, self).__init__()
        # print(1)
        self.pool = nn.AvgPool2d(patch_size)
        self.mean_val = mean_val
    def forward(self, x ):

        b,c,h,w = x.shape
        x = torch.mean(x,1,keepdim=True)
        mean = self.pool(x)

        d = torch.mean(torch.pow(mean- torch.FloatTensor([self.mean_val] ).cuda(),2))
        return d

3.颜色稳定性损失 L c o l L_{col} Lcol

遵循Gray-World颜色恒定假设,即每个传感器通道的颜色在整个图像上平均为灰色,设计了颜色恒定损失来校正增强图像中潜在的颜色偏差,并建立了三个通道关系进行调整。颜色恒定损失 L c o l L_{col} Lcol值可以表示为:
L c o l = ∑ ∀ ( p , q ) ∈ ε ( J p − J q ) 2 , ε = ( R , G ) , ( R , B ) , ( G , B ) (6) L_{col}=\sum_{\forall(p,q)\in\varepsilon}(J^p-J^q)^2,\varepsilon={{(R,G),(R,B),(G,B)}}\tag{6} Lcol=(p,q)ε(JpJq)2,ε=(R,G),(R,B),(G,B)(6)
其中 J p J^p Jp表示增强图像中 p p p通道的平均强度值, ( p , q ) (p,q) (pq)表示一对通道。
代码如下:

class L_color(nn.Module):

    def __init__(self):
        super(L_color, self).__init__()

    def forward(self, x ):

        b,c,h,w = x.shape

        mean_rgb = torch.mean(x,[2,3],keepdim=True)
        mr,mg, mb = torch.split(mean_rgb, 1, dim=1)
        Drg = torch.pow(mr-mg,2)
        Drb = torch.pow(mr-mb,2)
        Dgb = torch.pow(mb-mg,2)
        k = torch.pow(torch.pow(Drg,2) + torch.pow(Drb,2) + torch.pow(Dgb,2),0.5)


        return k

4.照明平滑度损失 L t v A L_{tv_\mathcal{A}} LtvA

为了保持相邻像素之间的单调关系,在每个曲线参数图 A \mathcal{A} A中添加一个光照平滑度损失。光照平滑度损失 L t v A L_{tv_\mathcal{A}} LtvA定义为:
L t v A = 1 N ∑ i = 1 N ∑ c ∈ ξ ( ∣ ∇ x A n c ∣ + ∇ y A n c ) 2 , ξ = R , G , B (7) L_{tv_\Alpha}=\frac{1}{N}\sum_{i=1}^N\sum_{c\in\xi}(|\nabla_x\mathcal{A}_n^c|+\nabla_y\mathcal{A}_n^c)^2,\xi={{R,G,B}}\tag{7} LtvA=N1i=1Ncξ(xAnc+yAnc)2,ξ=R,G,B(7)
其中 N N N表示迭代次数, ∇ x \nabla_x x ∇ y \nabla_y y表示水平和垂直方向上的梯度。
代码如下:

class L_TV(nn.Module):
    def __init__(self,TVLoss_weight=1):
        super(L_TV,self).__init__()
        self.TVLoss_weight = TVLoss_weight

    def forward(self,x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h =  (x.size()[2]-1) * x.size()[3]
        count_w = x.size()[2] * (x.size()[3] - 1)
        h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
        w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
        return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size

5.总损失

L t o t a l = L s p a + L e x p + W c o l L c o l + W t v A L t v A (8) L_{total}=L_{spa}+L_{exp}+W_{col}L_{col}+W_{tv_\mathcal{A}L_{tv_{\mathcal{A}}}}\tag{8} Ltotal=Lspa+Lexp+WcolLcol+WtvALtvA(8)
其中 W c o l W_{col} Wcol W t v A W_{tv_{\mathcal{A}}} WtvA为损失的权重。

实验结果

在这里插入图片描述

  • 1
    点赞
  • 32
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值