参考大佬文章和资料:
https://blog.csdn.net/qq_32734095/article/details/89059949
https://blog.csdn.net/qq_41202069/article/details/106589603
1、环境安装:
1.1 要求
Anaconda3
Ubuntu 18.04.5 LTS (cat /etc/issue)
Python3.6环境
Pytorch=0.3.1框架
CUDA9.1+CUDNN7.5
1.2 主要过程和问题:
(1)安装Pytorch>=0.3.1等
conda create -n pt031 python=3.6
conda activate pt031
pip3 install http://download.pytorch.org/whl/cpu/torch-0.3.1-cp36-cp36m-linux_x86_64.whl
pip install torchvision==0.2.1
pip install h5py
pip install scikit-image
cuda和cudnn安装遇到问题,由于本机安装了cuda10和tf等,折腾两个cuda版本并存,搞坏了pytorch,重新安装:
pip install torch-0.3.1-cp36-cp36m-linux_x86_64.whl --force-reinstall
(2)安装cuda、cudnn,下载,官网慢,在csdn下载找到可用资源付费下载
安装cuda:
sudo ./cuda_9.1.85_387.26_linux.run (报错)
sudo ./cuda_9.1.85_387.26_linux.run --override #(暂时可用,应该降gcc版本)
安装cudnn:
cp cudnn-9.1-linux-x64-v7.solitairetheme8 cudnn-9.1-linux-x64-v7.tgz
tar -zxvf cudnn-9.1-linux-x64-v7.tgz
sudo cp cudnn.h /usr/local/cuda-9.1/include/
sudo cp cudnn.h /usr/local/cuda-9.1/include/
两个版本cuda切换:
cd /usr/local/
sudo rm -rf cuda #(删除原来的软链接cuda10)
sudo ln -s /usr/local/cuda-9.1 /usr/local/cuda #(创建cuda9.1的软链接)
后期又遇到了其他安装包不兼容的问题,基本百度可以解决
2 复现复现过程参见参考连接,遇到的问题简单记录下(后补充的所以问题无图片截图)
(1)首先根据作者给出的预训练模型测试nat_new4(36个真实有雾图片的h5文件)
python demo.py --dataroot ./facades/nat_new4 --valDataroot ./facades/nat_new4 --netG ./demo_model/netG_epoch_8.pth
主要问题记录:
(a)
if opt.netG != '': netG.load_state_dict(torch.load(opt.netG))
报错,查看后发现模型读入后部分层丢失,tran_dense.dense_block1和tran_dense.dense_block2都不见了,原因不明(pytorch小白,后期还要补充理论知识),按照以下修改后可用,并且层没有丢失 :
if opt.netG != '': netG.load_state_dict(torch.load(opt.netG),strict=False)
(b)
input_cpu, target_cpu, depth_cpu, ato_cpu, imgname = data
报错,需要5个参数只得到4个,查阅分析代码后修改如下,简单粗暴的将文件名称这一个参数去掉,但由于后面打印还需要此参数,故将此参数用i来表示,目前跑代码没有影响:
input_cpu, target_cpu, depth_cpu, ato_cpu, imgname = data imgname = i
打印部分也相应修改
vutils.save_image(zz1, './result_cvpr18/image/real_dehazed/'+imgname[index2]+'_DCPCN.png', normalize=True, scale_each=False) 改为: vutils.save_image(zz1, './result_cvpr18/image/real_dehazed/' + str(imgname) + '_DCPCN.png', normalize=True,scale_each=False)
至此,第一步跑通,但是由于原图为h5文件无法直观对照结果,简单恢复下,代码:
import h5py
import os
import numpy as np
from PIL import Image
Image_NUM=36
f = h5py.File('./facades/nat_new4/1.h5', 'r')
first_level_key='./facades/nat_new4/'
output_path='./facades/nat_new4/images4'
f.keys() #可以查看所有的主键
print([key for key in f.keys()])
for n in range(Image_NUM):
path=first_level_key+str(n+1)+'.h5'
f = h5py.File(path, 'r')
i=0
for second_level_key in f.keys():
second_level_key))
image = np.array(f[second_level_key])*255
print(image)
image = Image.fromarray(np.uint8(image))
image.save(os.path.join(output_path, '%d'%(n+1)+second_level_key+'%03d.jpeg'%i))
i+=1
恢复后发现除了最后一组,其他的四幅图都是一样的,从generate_testsample.py中可以看出,测试样本的生成过程就是四个图片都用有雾图片作为输入;其他训练图片恢复后正常,4000+400组正常。
(2)接着用train512(4000个h5文件)微调作者给的预训练模型并保存下来
python train.py --dataroot ./facades/train512 --valDataroot ./facades/val512 --exp ./checkpoints_new --netG ./demo_model/netG_epoch_8.pth
主要问题记录:
(a)
val_batch_output[idx,:,:,:].copy_(dehaze21.data) 修改为: val_batch_output[idx, :, :, :].copy_(dehaze21.data.view_as(val_batch_output[idx, :, :, :]))
(b)
val_input_cpu, val_target_cpu, val_tran_cpu, val_ato_cpu, imgname = data_val 改为: val_input_cpu, val_target_cpu, val_tran_cpu, val_ato_cpu = data_val
(c)
nput_cpu, target_cpu, trans_cpu, ato_cpu, imgname = data 改为: input_cpu, target_cpu, trans_cpu, ato_cpu = data
(d)
读入数量报错,编号为0-4000,4001个样本,结果显示4005个,经过检查发现有三个名字为“*(1)”的重复样本,手动删除即可
此步骤生成很多结果,极占空间
3 使用自己的有雾图片去雾:
将自己的png格式的有雾图片,放到cvprw_test_resize_crop/文件夹下,用generate_testsample.py生成h5文件,结果输出在facades/test_cvpr/文件夹下,
python demo.py --dataroot ./facades/test_cvpr --valDataroot ./facades/test_cvpr --netG ./checkpoints_new/netG_epoch_8.pth
主要问题记录:
(a)
generate_testsample.py修改一处(路径不一致导致的):
root0 = os.path.abspath('./cvprw_test_resize_crop')
改为:
root = './cvprw_test_resize_crop'
(b)
用自己生成的图片报错,改用原来图片无误,研究下生成图片的程序generate_testsample.py:
RuntimeError: Given groups=1, weight[64, 3, 7, 7], so expected input[1, 4, 256, 512] to have 3 channels, but got 4 channels instead
解决方法:此处是由于输入的图片不是标准的RGB图像,需要在generate_testsample.py中将:
img=misc.imread(item) 改为: image = Image.open(item) img = image.convert("RGB")
注意生成的h5文件要是从0开始连续的命名,否则报错
至此,三步复现成功结束。