cycleGAN tensorflow框架代码理解
代码是cyclegan的https://github.com/XHUJOY/CycleGAN-tensorflow
注释代码可参考:https://github.com/enjlife/CycleGAN-tensorflow
运行遇到的问题
-
ImportError: libcublas.so.10.0: cannot open shared object file: No such file or directory
这里要看看各种版本的对应情况(可以在tensorflow官网查看),安装好了相应版本之后如果还是报这个错,就在命令行执行sudo ldconfig /usr/local/cuda-10.0/lib64
即可解决。
我使用的版本如下图: -
ValueError: Cannot feed value of shape (0,) for Tensor ‘real_A_and_B_images:0’, which has shape ‘(?, 256,256, 6)’
这个报错看看是不是没有放test图片数据,加上测试图片应该就能运行了
代码解读
程序间调用关系如图所示:
本文的解读顺序为:ops.py->model.py->module.py->urils.py->main.py
ops.py
这个子程序包括6个函数:
①batch_norm()
②instance_norm()
③conv2d()
④deconv2d()
⑤lrelu()
⑥linear()
(但是batch_norm函数以及linear函数在其他模块中好像并未被使用)
- batch_norm函数用来进行归一化,将数据拉回到而且避免发生梯度消失;
- instance_norm对HW做归一化,可以加速模型收敛,并且保持每个图像实例之间的独立;
- conv2d是自带的卷积函数;
- deconv2d是解卷积;
- lrelu函数返回输入与经过激活函数leaky relu之后值的最大值;
- linear函数定义了一个全连接操作。
urils.py
此部分可参考:https://mengfly.github.io/2018/11/24/CycleGAN笔记一.html
- 使用scipy.misc.imread将图片读取为数组,初始化了一个 _imread对象,用来后面加载图片和保存图片;
- get_stddev 是之后构建训练模型的时候初始化模型权重的时候要用到;b
3.定义了 一个ImagePool类,规定了图片缓存上限以及图片存储方法; - 还有加载测试图片的load_test_data函数,加载训练图片的load_train_data函数;
- get_image函数调用imread,center_crop和transform这三个函数,首先调用 imread 将图片加载成数据对象,之后调用 transform 对图像进行标准化,另外在 transform 中,判断是否对图像进行裁剪,裁剪之后,再通过scipy的 resize 方法对图像的大小进行重新调整,保证输入的数据维度都是一定的;
- save_image函数调用imsave函数,inverse_transform函数和merge函数,它先调用 inverse_transform 方法对图片数据进行反标准化,之后调用 imsave方法,对图像进行保存,Imsave调用了 merge 方法,将图像列表合并成了一张图像。这样,保存的就是图片列表合并的图片。
module.py
它包含了两种生成器:generator_unet和generator_resnet、一个判别器discriminator和三个计算误差的函数:abs_criterion,mae_criterion和sce_criterion。
- 生成器和判别器结构就不细讲了;(生成器默认使用Resnet)
- abs_criterion函数计算的是绝对值误差;
- mae_criterion函数计算的是均方误差;
- sce_criterion函数计算的是sigmod误差。
model.py
这个程序只写了一个Cyclegan类,这个类包含了
- 初始化__init__函数;
- build_model构建模型函数;
- train训练函数;
- test测试函数;
- load加载模型函数;
- sample_model输出样本结果函数;
- save模型存储函数。
main.py
main.py有一个参数管理的模块parser便于改变输入输出参数。
main函数是先确定程序路径, 创建session,创建cyclegan实例并根据参数判断是进行训练还是进行测试。
参考
测试结果
本人测试的是图像着色任务,和原目标不一样,原目标结果参阅论文:https://arxiv.org/pdf/1703.10593.pdf
着色任务结果稍后贴出,参数设置为:
(输入图像本应该是单通道,但为了代码成功运行,本次测试仍然保持输入通道为3)