DCPDN复现

参考大佬文章和资料:

https://blog.csdn.net/qq_32734095/article/details/89059949

https://blog.csdn.net/qq_41202069/article/details/106589603

1、环境安装:

​​​​​​​1.1 要求

Anaconda3

Ubuntu 18.04.5 LTS (cat /etc/issue)

Python3.6环境

Pytorch=0.3.1框架

CUDA9.1+CUDNN7.5

1.2 主要过程和问题:

(1)安装Pytorch>=0.3.1等
​​​​​​​

conda create -n pt031 python=3.6

conda activate pt031

pip3 install http://download.pytorch.org/whl/cpu/torch-0.3.1-cp36-cp36m-linux_x86_64.whl

pip install torchvision==0.2.1

pip install h5py

pip install scikit-image

cuda和cudnn安装遇到问题,由于本机安装了cuda10和tf等,折腾两个cuda版本并存,搞坏了pytorch,重新安装:

pip install torch-0.3.1-cp36-cp36m-linux_x86_64.whl  --force-reinstall

(2)安装cuda、cudnn,下载,官网慢,在csdn下载找到可用资源付费下载

安装cuda:

sudo ./cuda_9.1.85_387.26_linux.run (报错)

sudo ./cuda_9.1.85_387.26_linux.run --override #(暂时可用,应该降gcc版本)

安装cudnn:

cp cudnn-9.1-linux-x64-v7.solitairetheme8 cudnn-9.1-linux-x64-v7.tgz

tar -zxvf cudnn-9.1-linux-x64-v7.tgz 

sudo cp cudnn.h /usr/local/cuda-9.1/include/

sudo cp cudnn.h /usr/local/cuda-9.1/include/

两个版本cuda切换:
 

cd /usr/local/

sudo rm -rf cuda #(删除原来的软链接cuda10)

sudo ln -s /usr/local/cuda-9.1 /usr/local/cuda #(创建cuda9.1的软链接)

后期又遇到了其他安装包不兼容的问题,基本百度可以解决

2 复现复现过程参见参考连接,遇到的问题简单记录下(后补充的所以问题无图片截图)

(1)首先根据作者给出的预训练模型测试nat_new4(36个真实有雾图片的h5文件)

python demo.py --dataroot ./facades/nat_new4 --valDataroot ./facades/nat_new4 --netG ./demo_model/netG_epoch_8.pth

主要问题记录:

(a)
if opt.netG != '': netG.load_state_dict(torch.load(opt.netG))
报错,查看后发现模型读入后部分层丢失,tran_dense.dense_block1和tran_dense.dense_block2都不见了,原因不明(pytorch小白,后期还要补充理论知识),按照以下修改后可用,并且层没有丢失 :
if opt.netG != '': netG.load_state_dict(torch.load(opt.netG),strict=False)

(b)

input_cpu, target_cpu, depth_cpu, ato_cpu, imgname = data

报错,需要5个参数只得到4个,查阅分析代码后修改如下,简单粗暴的将文件名称这一个参数去掉,但由于后面打印还需要此参数,故将此参数用i来表示,目前跑代码没有影响:

input_cpu, target_cpu, depth_cpu, ato_cpu, imgname = data
imgname = i

打印部分也相应修改

vutils.save_image(zz1, './result_cvpr18/image/real_dehazed/'+imgname[index2]+'_DCPCN.png', normalize=True, scale_each=False)
改为:
vutils.save_image(zz1, './result_cvpr18/image/real_dehazed/' + str(imgname) + '_DCPCN.png', normalize=True,scale_each=False)

至此,第一步跑通,但是由于原图为h5文件无法直观对照结果,简单恢复下,代码:

import h5py
import os
import numpy as np
from PIL import Image
Image_NUM=36
f = h5py.File('./facades/nat_new4/1.h5', 'r')
first_level_key='./facades/nat_new4/'
output_path='./facades/nat_new4/images4'
f.keys() #可以查看所有的主键
print([key for key in f.keys()])
for n in range(Image_NUM):
        path=first_level_key+str(n+1)+'.h5'
        f = h5py.File(path, 'r')
        i=0
        for second_level_key in f.keys():
                second_level_key))
                image = np.array(f[second_level_key])*255
                print(image)
                image = Image.fromarray(np.uint8(image))
                image.save(os.path.join(output_path, '%d'%(n+1)+second_level_key+'%03d.jpeg'%i))
                i+=1
恢复后发现除了最后一组,其他的四幅图都是一样的,从generate_testsample.py中可以看出,测试样本的生成过程就是四个图片都用有雾图片作为输入;其他训练图片恢复后正常,4000+400组正常。

(2)接着用train512(4000个h5文件)微调作者给的预训练模型并保存下来

python train.py --dataroot ./facades/train512 --valDataroot ./facades/val512 --exp ./checkpoints_new --netG ./demo_model/netG_epoch_8.pth

主要问题记录:

(a)

val_batch_output[idx,:,:,:].copy_(dehaze21.data)
修改为:
val_batch_output[idx, :, :, :].copy_(dehaze21.data.view_as(val_batch_output[idx, :, :, :]))

(b)

val_input_cpu, val_target_cpu, val_tran_cpu, val_ato_cpu, imgname = data_val
改为:
val_input_cpu, val_target_cpu, val_tran_cpu, val_ato_cpu = data_val

 (c)

nput_cpu, target_cpu, trans_cpu, ato_cpu, imgname = data
改为:
input_cpu, target_cpu, trans_cpu, ato_cpu = data

(d)

读入数量报错,编号为0-4000,4001个样本,结果显示4005个,经过检查发现有三个名字为“*(1)”的重复样本,手动删除即可

此步骤生成很多结果,极占空间

3 使用自己的有雾图片去雾:

将自己的png格式的有雾图片,放到cvprw_test_resize_crop/文件夹下,用generate_testsample.py生成h5文件,结果输出在facades/test_cvpr/文件夹下,

python demo.py --dataroot ./facades/test_cvpr --valDataroot ./facades/test_cvpr --netG ./checkpoints_new/netG_epoch_8.pth

主要问题记录:

(a)

generate_testsample.py修改一处(路径不一致导致的):

root0 = os.path.abspath('./cvprw_test_resize_crop')

改为:

root = './cvprw_test_resize_crop'

(b)

用自己生成的图片报错,改用原来图片无误,研究下生成图片的程序generate_testsample.py:

RuntimeError: Given groups=1, weight[64, 3, 7, 7], so expected input[1, 4, 256, 512] to have 3 channels, but got 4 channels instead

解决方法:此处是由于输入的图片不是标准的RGB图像,需要在generate_testsample.py中将:

img=misc.imread(item)
改为:
image = Image.open(item)
img = image.convert("RGB")

注意生成的h5文件要是从0开始连续的命名,否则报错

至此,三步复现成功结束。

 

 

 

 

 

 

  • 1
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值