目录
出处
2020 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)
创新点
1.提出了第一个无参考网络(有待考究),避免了有参考网络中过度拟合的问题,提高了模型的泛化能力。
2.设计了一种特定于图像的曲线,能通过不断迭代,实现图像像素级增强的目的。
3.提出了适用于无参考网络的损失函数设计。
光照增强曲线设计
1.光照增强曲线设计应满足三个条件:
1)增强图像的每个像素值应保持在[0,-1]范围内;
2)曲线应单调,保持相邻像素之间的对比度;
3)曲线应该可微(适用于梯度反向传播)。
2.光照增强曲线
设计满足以上三个条件的曲线,公式如下:
L
E
(
I
(
x
)
;
α
)
=
I
(
x
)
+
α
I
(
x
)
(
1
−
I
(
x
)
)
(1)
LE(I(\bold{x});\alpha)=I(\bold{x})+\alpha I(\bold{x})(1-I(\bold{x}))\tag{1}
LE(I(x);α)=I(x)+αI(x)(1−I(x))(1)
其中
I
(
x
)
I(x)
I(x)表示输入图像,
α
\alpha
α是可训练参数。随着
α
\alpha
α不同,可实现图像对比度范围调整。下图是不同
α
\alpha
α值的
L
E
LE
LE曲线示意图:
显然,该曲线具有调整图像动态范围的能力。
尽管,
L
E
LE
LE曲线具有满足调整图像动态范围的能力,但是利用该曲线调整微光图像存在以下两个问题(仅个人观点):①对极度微光条件,其调整能力弱;②在进行微光增强的同时,会减弱强光部分,造成变换后的图像颜色失真。
该曲线实际只是压缩对比度的一种方法。
3.高阶光照增强曲线
为了解决2中提出的问题,本文在公式
(
1
)
(1)
(1)的基础上,设计了其高阶形式,同样满足1中提出的三个条件。具体公式如下:
L
E
n
(
x
)
=
L
E
n
−
1
+
α
n
L
E
n
−
1
(
x
)
(
1
−
L
E
n
−
1
(
x
)
)
(2)
LE_n(\bold{x})=LE_{n-1}+\alpha _{n}LE_{n-1}(\bold{x})(1-LE_{n-1}(\bold{x}))\tag{2}
LEn(x)=LEn−1+αnLEn−1(x)(1−LEn−1(x))(2)
当
n
n
n取值为8时,可以处理大多数微光图像增强问题。下图给出了当
n
n
n取值为8时,不同
α
\alpha
α值的曲线示意图:
如图所示,高阶曲线具有更强的微光增强能力,并且很好的保留高光部分。
4.像素级图像曲线
由公式
(
2
)
(2)
(2)可知
α
\alpha
α作用于所有像素,所以高阶曲线是一个全局调整过程。然而全局调整往往过度/不足增强局部区域。为了解决这个问题,将本文
α
\alpha
α表示为像素级参数,具体公式如下:
L
E
n
(
x
)
=
L
E
n
−
1
+
A
n
(
x
)
L
E
n
−
1
(
x
)
(
1
−
L
E
n
−
1
(
x
)
)
(3)
LE_n(\bold{x})=LE_{n-1}+\mathcal{A} _{n}(\bold{x})LE_{n-1}(\bold{x})(1-LE_{n-1}(\bold{x}))\tag{3}
LEn(x)=LEn−1+An(x)LEn−1(x)(1−LEn−1(x))(3)
其中
A
(
x
)
\mathcal{A} (\bold{x})
A(x)与输入图像尺寸一致。
下图给出了经过像素级图像曲线处理后,R,G,B通道结果图。
网络架构
该网络以UNet为主干网络,从中提取具有24通道的特征图像。将24通道按RGB划分8份,作为曲线参数 A 1 , A 2 , . . . A 8 \mathcal{A}_1,\mathcal{A}_2,...\mathcal{A}_8 A1,A2,...A8。根据公式 ( 3 ) (3) (3)对输入图像,迭代8次运算得到增强图像。
1.DCE-Net网络结构
如图所示,DCE-Net包含七个具有对称跳跃连接的卷积层。在前6个卷积层中,每个卷积层由32个大小为3×3、步长为1的卷积核组成,其后是RELU激活函数。最后一个卷积层由24个大小为3×3、步长为1的卷积核组成,其后是Tanh激活函数,该激活函数为8次迭代产生24个曲线参数映射,其中每次迭代需要3个通道(即,RGB通道)的3个曲线参数映射。
代码如下:
class enhance_net_nopool(nn.Module):
def __init__(self):
super(enhance_net_nopool, self).__init__()
self.relu = nn.ReLU(inplace=True)
number_f = 32
self.e_conv1 = nn.Conv2d(3,number_f,3,1,1,bias=True)
self.e_conv2 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
self.e_conv3 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
self.e_conv4 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
self.e_conv5 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
self.e_conv6 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
self.e_conv7 = nn.Conv2d(number_f*2,24,3,1,1,bias=True)
self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
def forward(self, x):
x1 = self.relu(self.e_conv1(x))
# p1 = self.maxpool(x1)
x2 = self.relu(self.e_conv2(x1))
# p2 = self.maxpool(x2)
x3 = self.relu(self.e_conv3(x2))
# p3 = self.maxpool(x3)
x4 = self.relu(self.e_conv4(x3))
x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))#沿行进行拼接
# x5 = self.upsample(x5)
x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))
x_r = F.tanh(self.e_conv7(torch.cat([x1,x6],1)))
r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 3, dim=1)
x = x + r1*(torch.pow(x,2)-x)
x = x + r2*(torch.pow(x,2)-x)
x = x + r3*(torch.pow(x,2)-x)
enhance_image_1 = x + r4*(torch.pow(x,2)-x)
x = enhance_image_1 + r5*(torch.pow(enhance_image_1,2)-enhance_image_1)
x = x + r6*(torch.pow(x,2)-x)
x = x + r7*(torch.pow(x,2)-x)
enhance_image = x + r8*(torch.pow(x,2)-x)
r = torch.cat([r1,r2,r3,r4,r5,r6,r7,r8],1)
return enhance_image_1,enhance_image,r
2.本文主要框架
损失函数设计
1.空间一致性损失 L s p a L_{spa} Lspa
L
s
p
a
=
1
K
∑
i
=
1
K
∑
j
∈
Ω
(
i
)
(
∣
(
Y
i
−
Y
j
)
∣
−
∣
(
I
i
−
I
j
)
∣
)
2
(4)
L_{spa}=\frac{1}{K}\sum_{i=1}^K\sum_{j\in\Omega(i)}(|(Y_i-Y_j)|-|(I_i-I_j)|)^2\tag{4}
Lspa=K1i=1∑Kj∈Ω(i)∑(∣(Yi−Yj)∣−∣(Ii−Ij)∣)2(4)
其中
K
K
K是局部区域的个数,
Ω
(
I
)
\Omega(I)
Ω(I)是以区域
i
i
i为中心的四个相邻区域(上、下、左、右),
Y
Y
Y和
I
I
I分别表示增强版本和输入图像中局部区域的平均强度值。该文经验性地将局部区域的大小设置为4×4。在其他区域大小的情况下,这种损失是稳定的。
代码如下:
class L_spa(nn.Module):
def __init__(self):
super(L_spa, self).__init__()
# print(1)kernel = torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
kernel_left = torch.FloatTensor( [[0,0,0],[-1,1,0],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
kernel_right = torch.FloatTensor( [[0,0,0],[0,1,-1],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
kernel_up = torch.FloatTensor( [[0,-1,0],[0,1, 0 ],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
kernel_down = torch.FloatTensor( [[0,0,0],[0,1, 0],[0,-1,0]]).cuda().unsqueeze(0).unsqueeze(0)
self.weight_left = nn.Parameter(data=kernel_left, requires_grad=False)
self.weight_right = nn.Parameter(data=kernel_right, requires_grad=False)
self.weight_up = nn.Parameter(data=kernel_up, requires_grad=False)
self.weight_down = nn.Parameter(data=kernel_down, requires_grad=False)
self.pool = nn.AvgPool2d(4)
def forward(self, org , enhance ):
b,c,h,w = org.shape
org_mean = torch.mean(org,1,keepdim=True)
enhance_mean = torch.mean(enhance,1,keepdim=True)
org_pool = self.pool(org_mean)
enhance_pool = self.pool(enhance_mean)
weight_diff =torch.max(torch.FloatTensor([1]).cuda() + 10000*torch.min(org_pool - torch.FloatTensor([0.3]).cuda(),torch.FloatTensor([0]).cuda()),torch.FloatTensor([0.5]).cuda())
E_1 = torch.mul(torch.sign(enhance_pool - torch.FloatTensor([0.5]).cuda()) ,enhance_pool-org_pool)
D_org_letf = F.conv2d(org_pool , self.weight_left, padding=1)
D_org_right = F.conv2d(org_pool , self.weight_right, padding=1)
D_org_up = F.conv2d(org_pool , self.weight_up, padding=1)
D_org_down = F.conv2d(org_pool , self.weight_down, padding=1)
D_enhance_letf = F.conv2d(enhance_pool , self.weight_left, padding=1)
D_enhance_right = F.conv2d(enhance_pool , self.weight_right, padding=1)
D_enhance_up = F.conv2d(enhance_pool , self.weight_up, padding=1)
D_enhance_down = F.conv2d(enhance_pool , self.weight_down, padding=1)
D_left = torch.pow(D_org_letf - D_enhance_letf,2)
D_right = torch.pow(D_org_right - D_enhance_right,2)
D_up = torch.pow(D_org_up - D_enhance_up,2)
D_down = torch.pow(D_org_down - D_enhance_down,2)
E = (D_left + D_right + D_up +D_down)
# E = 25*(D_left + D_right + D_up +D_down)
return E
2.曝光控制损失 L e x p L_{exp} Lexp
L
e
x
p
=
1
K
∑
k
=
1
M
∣
Y
k
−
E
∣
(5)
L_{exp}=\frac{1}{K}\sum_{k=1}^M|Y_k-E|\tag{5}
Lexp=K1k=1∑M∣Yk−E∣(5)
曝光控制损失衡量局部区域的平均强度值与曝光良好的级别E之间的距离。其中E设置为RGB颜色空间中的灰度级别。在实验中,将E设置为0.6,本文作者指出没有发现将E设置在[0.4,0.7]范围内会有太大的性能差异。
代码如下:
class L_exp(nn.Module):
def __init__(self,patch_size,mean_val):
super(L_exp, self).__init__()
# print(1)
self.pool = nn.AvgPool2d(patch_size)
self.mean_val = mean_val
def forward(self, x ):
b,c,h,w = x.shape
x = torch.mean(x,1,keepdim=True)
mean = self.pool(x)
d = torch.mean(torch.pow(mean- torch.FloatTensor([self.mean_val] ).cuda(),2))
return d
3.颜色稳定性损失 L c o l L_{col} Lcol
遵循Gray-World颜色恒定假设,即每个传感器通道的颜色在整个图像上平均为灰色,设计了颜色恒定损失来校正增强图像中潜在的颜色偏差,并建立了三个通道关系进行调整。颜色恒定损失
L
c
o
l
L_{col}
Lcol值可以表示为:
L
c
o
l
=
∑
∀
(
p
,
q
)
∈
ε
(
J
p
−
J
q
)
2
,
ε
=
(
R
,
G
)
,
(
R
,
B
)
,
(
G
,
B
)
(6)
L_{col}=\sum_{\forall(p,q)\in\varepsilon}(J^p-J^q)^2,\varepsilon={{(R,G),(R,B),(G,B)}}\tag{6}
Lcol=∀(p,q)∈ε∑(Jp−Jq)2,ε=(R,G),(R,B),(G,B)(6)
其中
J
p
J^p
Jp表示增强图像中
p
p
p通道的平均强度值,
(
p
,
q
)
(p,q)
(p,q)表示一对通道。
代码如下:
class L_color(nn.Module):
def __init__(self):
super(L_color, self).__init__()
def forward(self, x ):
b,c,h,w = x.shape
mean_rgb = torch.mean(x,[2,3],keepdim=True)
mr,mg, mb = torch.split(mean_rgb, 1, dim=1)
Drg = torch.pow(mr-mg,2)
Drb = torch.pow(mr-mb,2)
Dgb = torch.pow(mb-mg,2)
k = torch.pow(torch.pow(Drg,2) + torch.pow(Drb,2) + torch.pow(Dgb,2),0.5)
return k
4.照明平滑度损失 L t v A L_{tv_\mathcal{A}} LtvA
为了保持相邻像素之间的单调关系,在每个曲线参数图
A
\mathcal{A}
A中添加一个光照平滑度损失。光照平滑度损失
L
t
v
A
L_{tv_\mathcal{A}}
LtvA定义为:
L
t
v
A
=
1
N
∑
i
=
1
N
∑
c
∈
ξ
(
∣
∇
x
A
n
c
∣
+
∇
y
A
n
c
)
2
,
ξ
=
R
,
G
,
B
(7)
L_{tv_\Alpha}=\frac{1}{N}\sum_{i=1}^N\sum_{c\in\xi}(|\nabla_x\mathcal{A}_n^c|+\nabla_y\mathcal{A}_n^c)^2,\xi={{R,G,B}}\tag{7}
LtvA=N1i=1∑Nc∈ξ∑(∣∇xAnc∣+∇yAnc)2,ξ=R,G,B(7)
其中
N
N
N表示迭代次数,
∇
x
\nabla_x
∇x和
∇
y
\nabla_y
∇y表示水平和垂直方向上的梯度。
代码如下:
class L_TV(nn.Module):
def __init__(self,TVLoss_weight=1):
super(L_TV,self).__init__()
self.TVLoss_weight = TVLoss_weight
def forward(self,x):
batch_size = x.size()[0]
h_x = x.size()[2]
w_x = x.size()[3]
count_h = (x.size()[2]-1) * x.size()[3]
count_w = x.size()[2] * (x.size()[3] - 1)
h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum()
w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum()
return self.TVLoss_weight*2*(h_tv/count_h+w_tv/count_w)/batch_size
5.总损失
L
t
o
t
a
l
=
L
s
p
a
+
L
e
x
p
+
W
c
o
l
L
c
o
l
+
W
t
v
A
L
t
v
A
(8)
L_{total}=L_{spa}+L_{exp}+W_{col}L_{col}+W_{tv_\mathcal{A}L_{tv_{\mathcal{A}}}}\tag{8}
Ltotal=Lspa+Lexp+WcolLcol+WtvALtvA(8)
其中
W
c
o
l
W_{col}
Wcol和
W
t
v
A
W_{tv_{\mathcal{A}}}
WtvA为损失的权重。