——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——
论文题目:Dual Contrastive Learning for Unsupervised Image-to-Image Translation
论文地址:https://arxiv.org/abs/2104.07689
——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——
关键词
深度学习;图像转换;代码复现;DCLGAN;Dual Contrasive Learning
——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——
一、工作准备
1. 源码下载
Github源代码:https://github.com/JunlinHan/DCLGAN
选择 Code → Download ZIP 下载即可。
之后对文件进行解压,解压后包含如下文件:
☆☆☆----解压后在根目录中分别创建如下两个文件,分别是:
① checkpoints文件(用于存放训练好的模型权重文件)
② results文件(用于存放测试后结果)
2. 数据集下载
本文将利用作者提供的 maps 数据集进行模型训练和测试。
数据集下载链接:https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/
将数据集下载后放入datasets文件并解压,数据集内具体内容如下:
其中,A表示遥感影像,B表示对应的二维地图,如下:
3. 配置必要的环境
可以在environment.yml文件中查看具体所需的环境配置,如下:
其中,Python的版本选用3.6及以上版本都可以,本文所用的Python版本为3.9.12。
对于Conda用户,可以利用下列代码进行环境创建:
conda env create -f 文件存放路径/environment.yml
之后激活此环境
conda activate your_env_name (虚拟环境名称)
本篇Blog将调用远程服务器对代码进行调试和复现,关于本地Windows系统上Pycharm连接远程Linux服务器的相关操作可参考上一篇博客:【深度学习进阶之路】----Pycharm连接远程服务器进行代码调试与开发
**插播:**为了方便在Pycharm中实现远程服务器文件目录的可视化,可以在Pycharm中选择Tools–>Deployment–>Browse Remote Host
——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——
二、训练数据集
1. 配置训练文件
① 在Pycharm中,点击Run–>Edit Configurations,进行配置训练文件。
其中,--dataroot ./datasets/maps
表示数据集存放位置;--name maps_DCL
表示将在checkpoints文件夹中新建一个名为maps_DCL的文件夹,用于存放训练好的权重;--model dcl
表示使用的模型为dcl。
2. 相关参数的修改
① 在train_options.py
中修改训练epoch及学习率,注意epoch要修改两处,二者之和便为总的训练epoch。
② 在base_options.py
中修改batch size及图片尺寸
当我们进行了上述操作后,便可美美哒运行train.py
啦。
——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——
**插播:**当博主运行train.py
时,出现以下两条警告信息,虽然不影响模型训练,但还是不想看到红色字体。
UserWarning: size_average and reduce args will be deprecated, please use reduction='mean' instead.
出现此条UserWarning
,是因为pytorch不同版本进行更新迭代时引起的警告,某些参数被取代了,解决方案:
self.criterionSim = torch.nn.L1Loss('sum').to(self.device)
改为:
self.criterionSim = torch.nn.L1Loss(reduction='sum').to(self.device)
UserWarning: Argument interpolation should be of type InterpolationMode instead of int. Please, use InterpolationMode enum.
"Argument interpolation should be of type InterpolationMode instead of int. "
出现此条UserWarning
,是torchvision和pillow不兼容导致的,我的环境里torchvision=0.11.3 and pillow=6.1.0,即使我把pillow升级到8.3.1,依然有warning。那只能降低torchvision了,但是torchvision的版本号一般都是和pytorch绑定好的,我们需要不依赖torch来更改torchvison的版本,这可以通过以下指令实现:
self.criterionSim = torch.nn.L1Loss('sum').to(self.device)
改为:
self.criterionSim = torch.nn.L1Loss(reduction='sum').to(self.device)
——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——
OK,准备就绪,代码已经可以完美的 run了。
大概经过 339s×200epoches≈18.83h 的训练,模型已基本训练完成,如下:
Oh Yeah, Process finished with exit code 0 !!!
——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——
三、测试训练好的模型
训练好的模型,存放在根目录文件夹 checkpoints 中,如下:
1. 配置测试文件
之后配置测试文件(同“二、2”),便可运行test.py
文件,如下:
其中,–dataroot ./datasets/maps表示数据集存放位置;–name maps_DCL表示将在results文件夹中新建一个名为maps_DCL的文件夹,用于存放测试结果。
Oh Yeah, Process finished with exit code 0 !!!
2. 结果展示
在results文件夹中,点击index.html
即可在线查看模型测试结果,如下:
放几张图,效果嘛,自行体会吧~~~~~~
——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——+——
参考
本篇博客在代码复现过程中,参考了以下几位大神的文章,在此拜谢。