【低照度图像增强系列(6)】IceNet算法详解与代码实现(IEEE)

本文介绍了IceNet,一种基于CNN的交互式图像增强算法,针对低光照条件下的目标检测问题。通过用户提供的参数和局部注解,IceNet自适应生成增强图像,包含交互式亮度控制、熵损失、平滑损失等模块。算法能显著提高在暗环境下目标检测的精度。
摘要由CSDN通过智能技术生成

前言 

☀️ 在低照度场景下进行目标检测任务,常存在图像RGB特征信息少提取特征困难目标识别和定位精度低等问题,给检测带来一定的难度。

     🌻使用图像增强模块对原始图像进行画质提升,恢复各类图像信息,再使用目标检测网络对增强图像进行特定目标检测,有效提高检测的精确度。

      ⭐本专栏会介绍传统方法、Retinex、EnlightenGAN、SCI、Zero-DCE、IceNet、RRDNet、URetinex-Net等低照度图像增强算法。

👑完整代码已打包上传至资源→低照度图像增强代码汇总资源

目录

前言 

🚀一、IceNet介绍  

☀️1.1 IceNet简介  

🚀二、IceNet网络结构及核心代码

☀️2.1 网络结构 

☀️2.2 核心代码 

🚀三、IceNet损失函数及核心代码

☀️3.1 Interactive brightness control loss—交互式亮度控制损失

☀️3.2  Entropy loss—熵损失

☀️3.3 Smoothness loss—平滑损失

☀️3.4 Total loss—总损失

🚀四、IceNet代码复现

☀️4.1 环境配置

☀️4.2 运行过程

☀️4.3 运行效果

🚀一、IceNet介绍  

相关资料: 

☀️1.1 IceNet简介  

本文提出了一种基于 CNN 的交互式对比度增强算法,称为 IceNet,该算法使用户能够根据自己的喜好轻松调整图像对比度。

具体来说,用户提供用于控制全局亮度的参数两种类型的scribble来使图像中的局部区域变暗或变亮。然后,根据这些注释,IceNet 估计用于逐像素伽玛校正的伽玛图。最后,通过色彩恢复,得到增强后的图像。用户可以迭代地提供注释以获得满意的图像。

IceNet还能够自动生成个性化的增强图像,如果需要的话可以作为进一步调整的基础。

本文主要贡献

  • 本文提出了第一个基于CNN的交互式CE算法,称为IceNet,它可以根据用户的偏好,自适应地生成增强的图像,也可以自动生成无需交互的图像。
  • 本文用提出的三个可微损失函数训练IceNet,从而实现与用户的交互,并产生效果不错的增强图像。
  • 本文通过各种实验结果,证明IceNet可以为用户提供满意的结果,明显优于传统算法。


🚀二、IceNet网络结构及核心代码

☀️2.1 网络结构 

通过检查输入图像I,用户提供一个曝光等级η∈[0,1]来控制全局亮度和两种scribble类型(红色蓝色scribble分别表示用户想要使相应的局部区域变暗或变亮。步骤如下:

  • 首先,在scribble图S中分别记为−1和1,其余像素赋值为0
  • 接着,将RGB彩色图像I转换到YCbCr空间只调整亮度分量Y,同时保留色度分量。
  • 然后,估计一张伽马图Γ,用于y的像素级伽马校正。
  • 最后,通过颜色恢复,得到增强后的图像J

☀️2.2 核心代码 

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

class IceNet(nn.Module):

	def __init__(self):
		super(IceNet, self).__init__()

		self.relu = nn.ReLU(inplace=True)
		# 7个卷积层用于提取特征
		self.e_conv1 = nn.Conv2d(2,32,3,1,1,bias=True) 
		self.e_conv2 = nn.Conv2d(32,32,3,1,1,bias=True) 
		self.e_conv3 = nn.Conv2d(32,32,3,1,1,bias=True) 
		self.e_conv4 = nn.Conv2d(32,32,3,1,1,bias=True) 
		self.e_conv5 = nn.Conv2d(64,32,3,1,1,bias=True) 
		self.e_conv6 = nn.Conv2d(64,32,3,1,1,bias=True) 
		self.e_conv7 = nn.Conv2d(64,32,3,1,1,bias=True)
        # 两个全连接层 (fc1 和 fc2),用于生成自适应向量。
		self.fc1 = nn.Linear(1, 32)
		self.fc2 = nn.Linear(32, 32)


	def forward(self, y, maps, e, lowlight=None, is_train=False):
		b, _, h, w = y.shape
		x_ = torch.cat([y, maps], 1) # y 和 maps 是输入的张量

		# generate adaptive vector according to eta
		W = self.relu(self.fc1(e)) # e是一个向量,用于生成自适应增强的参数。
		W = self.fc2(W)

		# feature extractor
		x1 = self.relu(self.e_conv1(x_))
		x2 = self.relu(self.e_conv2(x1))
		x3 = self.relu(self.e_conv3(x2))
		x4 = self.relu(self.e_conv4(x3))
		x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))
		x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))
		x_r = self.relu(self.e_conv7(torch.cat([x1,x6],1)))

		# 使用自适应增强方法 AGEB 对 x_r 进行增强。
		x_r = F.conv2d(x_r.view(1, b * 32, h, w),
				W.view(b, 32, 1, 1), groups=b)
		x_r = torch.sigmoid(x_r).view(b, 1, h, w) * 10

		# 进行 gamma 校正,得到增强后的图像 enhanced_Y。
		enhanced_Y = torch.pow(y,x_r)
		if is_train:
			return enhanced_Y, x_r # 如果处于训练模式,返回增强后的图像和增强参数 x_r;
		else:
			# color restoration
			enhanced_image = torch.clip(enhanced_Y*(lowlight/y), 0, 1)
			return enhanced_image # 否则,对增强后的图像进行颜色还原,得到 enhanced_image,并返回。

 步骤如下:

  • 首先,对输入的 y maps 进行concat,生成新的输入 x_
  • 然后,通过全连接层计算自适应向量 W,并使用 ReLU 激活函数。
  • 接下来,通过多个卷积层提取特征,得到 x_r。再使用 AGEB 方法对 x_r 进行自适应增强。
  • 最后,对增强后的图像进行 gamma 校正,得到 enhanced_Y
  • 如果是训练模式,则返回增强后的图像和增强参数 x_r;否则,对增强后的图像进行颜色还原,得到 enhanced_image,并返回。

🚀三、IceNet损失函数及核心代码

这篇文章的贡献也是,主要在损失函数上。

☀️3.1 Interactive brightness control loss—交互式亮度控制损失

  • 首先,在输入亮度Y上加上scribble的S
  • 接着,将Y归一化到[0,1]的范围,得到

  • 最后,使用双边伽玛调整方案来提高细节在暗和亮区域的可见性

  • 式(7):暗区增强
  • 式(8):亮区增强
  • 式(9):将这两种结果结合起来,同时保留暗区和亮区细节  

☀️3.2  Entropy loss—熵损失

设计目的

最大熵是通过均匀分布实现的,因此通过均衡输出图像的直方图,采用熵损失来增加全局对比度。
直方图由于其不可微性,不能直接使用。所以本文设计了一个软直方图。

图4(a)显示了σ = 5、10或20时的软映射函数。

  • σ控制映射函数的宽度和高度之间的权衡。随着σ的增大,映射函数变窄变高。
  • 在本文中,设定σ = 10。

通过对所有像素的贡献求和,得到软直方图

定义为熵的倒数


☀️3.3 Smoothness loss—平滑损失

引入目的: 为了促进式(1)中伽马图Γ中相邻值之间的平滑变化


☀️3.4 Total loss—总损失

总损失定义为三种损失的加权和

  • 首先,使IceNet能够控制全局和局部亮度。
  • 其次,促进一个平滑的直方图的形成,这可以增加整体对比度。
  • 第三,平滑伽马图。

图4(b) ~ (e)说明了每一次损失的效果

 代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torchvision.models.vgg import vgg16
import numpy as np
from numpy.testing import assert_almost_equal

class L_ent(nn.Module):
    def __init__(self, bins, min, max, sigma):
        super(L_ent, self).__init__()
        self.bins = bins # 表示直方图的分箱数量
        self.min = min # 表示直方图的最小值
        self.max = max # 表示直方图的最大值
        self.sigma = sigma # 超参数
        self.delta = float(max - min) / float(bins) # 计算直方图分箱的间隔大小
        self.centers = float(min) + self.delta * (torch.arange(bins).float().cuda() + 0.5) # 计算直方图的中心

    def forward(self, y):
        b, _, h, w = y.shape # 获取输入张量 y 的形状信息
        y = y.reshape(b, 1, -1) # -1 表示自动推断维度。
        c = self.centers.reshape(1, -1, 1).repeat(b, 1, 1) # 计算直方图的中心c
        x = y - c # 计算差值 x,即每个像素值与直方图中心的距离。
        # (x + self.delta/2) 和 (x - self.delta/2) 分别对应直方图箱的右边界和左边界。
        x = torch.sigmoid(self.sigma * (x + self.delta/2)) - torch.sigmoid(self.sigma * (x - self.delta/2))

        hist = torch.sum(x, 2)
        p = hist / (h * w) + 1e-6 # 计算直方图 hist,并归一化为概率 p。
        d = torch.sum((-p * torch.log(p))) # 计算交叉熵损失 d。
        return 1/d

class L_int(nn.Module):

    def __init__(self):
        super(L_int, self).__init__()

    def forward(self, x, mean_val, labels):
        b,c,h,w = x.shape
        x = torch.mean(x,1,keepdim=True)
        d = torch.mean(torch.pow(x- labels,2)) # 计算平均值与标签之间的均方误差 d

        return d
        
class L_smo(nn.Module):
    def __init__(self):
        super(L_smo,self).__init__()

    def forward(self,x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h =  (x.size()[2]-1) * x.size()[3]
        count_w = x.size()[2] * (x.size()[3] - 1)
        h_tv = torch.pow((x[:,:,1:,:]-x[:,:,:h_x-1,:]),2).sum() # h_tv 表示水平方向上的总变化数
        w_tv = torch.pow((x[:,:,:,1:]-x[:,:,:,:w_x-1]),2).sum() # w_tv 表示垂直方向上的总变化数。
        return 2*(h_tv/count_h+w_tv/count_w)/batch_size # 计算总的平滑损失值,并除以批量大小,得到最终的损失值。

🚀四、IceNet代码复现

☀️4.1 环境配置

  • Python 3.7
  • Pytorch 1.0.0
  • opencv
  • torchvision 0.2.1
  • cuda 10.0

☀️4.2 运行过程

这个也是运行比较简单,配好环境就行 。不再过多叙述~


☀️4.3 运行效果


  • 37
    点赞
  • 43
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
以下是一些matlab照度增强的代码和工具: 1. 基于Retinex的增强方法 ```matlab function [output] = retinex(input, sigma) % 基于Retinex的增强方法 % input: 输入图像 % sigma: 高斯核标准差 % output: 增强后的图像 % 参考文献:http://www.ipol.im/pub/art/2014/107/ % 代码来源:https://github.com/yskmt/retinex input = double(input); output = zeros(size(input)); for i = 1:3 log_input = log(input(:,:,i) + 1); log_input = imgaussfilt(log_input, sigma); log_input = log_input - imgaussfilt(log_input, 2*sigma); output(:,:,i) = exp(log_input) - 1; end output = uint8(output); end ``` 2. 基于暗通道先验的增强方法 ```matlab function [output] = dark_channel_prior(input, patch_size, w) % 基于暗通道先验的增强方法 % input: 输入图像 % patch_size: 暗通道先验中的窗口大小 % w: 前景区域比例 % output: 增强后的图像 % 参考文献:http://kaiminghe.com/publications/cvpr09.pdf % 代码来源:https://github.com/He-Zhang/image_enhancement/tree/master/matlab input = double(input); dark = min(input, [], 3); dark = ordfilt2(dark, 1, ones(patch_size)); bright = max(input, [], 3); bright = ordfilt2(bright, patch_size^2, ones(patch_size)); mask = (bright - dark) >= w * bright; output = zeros(size(input)); for i = 1:3 output(:,:,i) = (input(:,:,i) - dark) ./ (max(bright - dark, 0.01)) .* mask + input(:,:,i) .* (1 - mask); end output = uint8(output); end ``` 3. 基于深度学习的增强方法 ```matlab % 参考文献:https://ieeexplore.ieee.org/document/8332287 % 代码来源:https://github.com/cszn/DnCNN/tree/master/testsets/Set12 % 下载预训练模型:https://github.com/cszn/DnCNN/releases/download/v1.0/dncnn_gray_blind.mat net = load('dncnn_gray_blind.mat'); net = net.net; net = vl_simplenn_tidy(net); net.layers(end) = []; net.layers(end) = []; net = vl_simplenn_tidy(net); input = imread('lena.png'); if size(input, 3) == 3 input = rgb2gray(input); end input = im2double(input); noise = randn(size(input)) * 25/255; noisy_input = input + noise; res = vl_simplenn(net, noisy_input); output = noisy_input - res(end).x; output = im2uint8(output); ```
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

路人贾'ω'

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值