【代码复现Zero-DCE详解:Zero-Reference Deep Curve Estimation for Low-Light Image Enhancement】

在这里插入图片描述

链接概括

1.文章:(CVPR 2020) Zero-Reference Deep Curve Estimation for Low-Light Image Enhancement
2. 链接: paper.
3. 链接: code.
4. 其他博主复现链接: link.

存在的几个主要问题

1.检查电脑是否含有GPU,不是GPU环境需要将代码改为CPU环境下
2. 源代码中没有数据集
3. 代码中路径问题,路径不对将导致无法加载数据
3.需要按照要求设置文件夹
4. 由于版本不同,代码运行中出现的警告影响代码运行

检查电脑是否含有GPU:将代码改为CPU环境下运行

论文中的代码是GPU环境下的,并且是在每一句需要用到GPU的代码下注释的,所以需要将所有含有该语令的改为CPU
具体改法:
将代码中".cuda()"删去,或者将其改为“.cpu()”
下面仅仅是部分示例:其中前面带“#”为源代码

// Myloss.py
        # kernel_left = torch.FloatTensor( [[0,0,0],[-1,1,0],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
        # kernel_right = torch.FloatTensor( [[0,0,0],[0,1,-1],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
        # kernel_up = torch.FloatTensor( [[0,-1,0],[0,1, 0 ],[0,0,0]]).cuda().unsqueeze(0).unsqueeze(0)
        # kernel_down = torch.FloatTensor( [[0,0,0],[0,1, 0],[0,-1,0]]).cuda().unsqueeze(0).unsqueeze(0)
        kernel_left = torch.FloatTensor([[0, 0, 0], [-1, 1, 0], [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
        kernel_right = torch.FloatTensor([[0, 0, 0], [0, 1, -1], [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
        kernel_up = torch.FloatTensor([[0, -1, 0], [0, 1, 0], [0, 0, 0]]).unsqueeze(0).unsqueeze(0)
        kernel_down = torch.FloatTensor([[0, 0, 0], [0, 1, 0], [0, -1, 0]]).unsqueeze(0).unsqueeze(0)

        # weight_diff =torch.max(torch.FloatTensor([1]).cuda() + 10000*torch.min(org_pool - torch.FloatTensor([0.3]).cuda(),torch.FloatTensor([0]).cuda()),torch.FloatTensor([0.5]).cuda())
        weight_diff =torch.max(torch.FloatTensor([1]) + 10000*torch.min(org_pool - torch.FloatTensor([0.3]),torch.FloatTensor([0])),torch.FloatTensor([0.5]))
        # E_1 = torch.mul(torch.sign(enhance_pool - torch.FloatTensor([0.5]).cuda()) ,enhance_pool-org_pool)
        E_1 = torch.mul(torch.sign(enhance_pool - torch.FloatTensor([0.5])), enhance_pool - org_pool)

在CPU条件下 在训练模型中需要将pin_memory=True改为pin_memory=False

//  lowlight_train.py
   train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True)
	# train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True,num_workers=config.num_workers, pin_memory=False)

源代码中没有训练数据集

该论文源代码里面有预训练模型,所以没有训练数据集,并不影响直接测试代码。
该Epoch99.pth即为预训练模型。所以训练集并不影响测试。
在这里插入图片描述
当然你也可以根据代码要求寻找数据集加入,并训练出自己的预训练模型

代码中路径问题,路径不对将导致无法加载数据

在dataloader.py文件中 :
作用是为数据做预处理的相关代码。
文件路径出现了“.jpg”,如果训练集中的数据为PNG格式,就需要将最后改为“.png”

// dataloader.py
def populate_train_list(lowlight_images_path):#获取训练列表(微光图像路径)

	# image_list_lowlight = glob.glob(lowlight_images_path + "*.jpg")
	image_list_lowlight = glob.glob(lowlight_images_path + "*.png")

	train_list = image_list_lowlight

	random.shuffle(train_list)

	return train_list

在 lowlight_train.py文件中 :
将文件路径改为自己的数据路径
绝对路径,相对路径都可以
“在开始调试的代码过程中可以更加学习率与减少迭代次数,在代码调通后在改为规定的参数,可以运行更快”

// lowlight_train.py
if __name__ == "__main__":

	parser = argparse.ArgumentParser()

	# Input Parameters
	# parser.add_argument('--lowlight_images_path', type=str, default="data/train_data/")
	parser.add_argument('--lowlight_images_path', type=str, default="E:/image/Zero-DCE-master/Zero-DCE_code/data/train_data/")
	# parser.add_argument('--lr', type=float, default=0.0001)
	# parser.add_argument('--weight_decay', type=float, default=0.0001)
	parser.add_argument('--grad_clip_norm', type=float, default=0.1)
	parser.add_argument('--lr', type=float, default=0.01)
	parser.add_argument('--weight_decay', type=float, default=0.01)
	parser.add_argument('--num_epochs', type=int, default=20)
	# parser.add_argument('--num_epochs', type=int, default=20)
	parser.add_argument('--train_batch_size', type=int, default=8)
	parser.add_argument('--val_batch_size', type=int, default=4)
	parser.add_argument('--num_workers', type=int, default=4)
	parser.add_argument('--display_iter', type=int, default=10)
	parser.add_argument('--snapshot_iter', type=int, default=10)
	parser.add_argument('--snapshots_folder', type=str, default="E:/image/Zero-DCE-master/Zero-DCE_code/snapshots/")
	# parser.add_argument('--snapshots_folder', type=str, default="snapshots/")
	parser.add_argument('--load_pretrain', type=bool, default= False)
	parser.add_argument('--pretrain_dir', type=str, default= "E:/image/Zero-DCE-master/Zero-DCE_code/snapshots/Epoch99.pth")
	# parser.add_argument('--pretrain_dir', type=str, default="snapshots/Epoch99.pth")
	config = parser.parse_args()

需要按照要求设置文件夹

在data文件夹下设置result文件夹,并在result文件下设置与test_data文件夹下一样的两个文件夹:DICM与LIME文件夹
必须一样,否则将报错
在这里插入图片描述

由于版本不同,代码运行中出现的警告影响代码运行

一、 torch.nn.utils.clip_grad_norm函数被弃用

//  torch.nn.utils.clip_grad_norm函数被弃用 的警告
UserWarning: torch.nn.utils.clip_grad_norm is now deprecated in favor of torch.nn.utils.clip_grad_norm_. 
torch.nn.utils.clip_grad_norm(DCE_net.parameters(),config.grad_clip_norm)

改法;

//  lowlight_train.py
        # torch.nn.utils.clip_grad_norm(DCE_net.parameters(),config.grad_clip_norm)
   		torch.nn.utils.clip_grad_norm_(DCE_net.parameters(),config.grad_clip_norm)

二、 nn.functional.tanh被弃用

//  nn.functional.tanh被弃用的警告
UserWarning: nn.functional.tanh is deprecated. Use torch.tanh instead.
 warnings.warn("nn.functional.tanh is deprecated. Use torch.tanh instead.")

改法

//  model.py
    x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))

   	# x5 = self.upsample(x5)
   	x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))

   	# x_r = F.tanh(self.e_conv7(torch.cat([x1,x6],1)))  #改之前
   	x_r = torch.tanh(self.e_conv7(torch.cat([x1, x6], 1)))
   	r1,r2,r3,r4,r5,r6,r7,r8 = torch.split(x_r, 3, dim=1)
  • 1
    点赞
  • 33
    收藏
    觉得还不错? 一键收藏
  • 8
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 8
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值