博客简述
根据复现pytorch-CycleGAN-and-pix2pix的代码总结自己的项目结构,方便今后自己的添加模块。
项目地址
项目下载地址:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
参考博客链接:https://blog.csdn.net/weixin_46235765/article/details/119026209?spm=1001.2014.3001.5506
项目整体结构
文件夹checkpoint
该文件夹不仅保存模型的参数,优化器参数,还有loss,epoch等(相当于一个保存模型的文件夹),通常在训练模型的过程中,每隔一段时间就将训练模型信息保存一次(包含模型的参数信息,还包含其他信息,如当前的迭代次数,优化器的参数等,以便用于后面恢复)
文件夹data
该文件夹包括数据的加载和处理以及用户可制作自己的数据集,介绍一下以下七种数据集类。各种类也需要继承BaseData写各种方法。
init_.py: 实现包和train、test脚本之间的接口。
base_dataset.py:继承了 torch 的 dataset 类和抽象基类,该文件还包括了一些常用的图片转换方法,方便后续子类使用。
image_folder.py:目的就是获得指定目录下的图片路径和加载路径图片。
template_dataset.py:为制作自己数据集提供了模板和参考,里面注释一些细节信息,就是提供一个实现自定义数据集的模板。
single_dataset.py:继承BaseDataset类定义最简单的dataset类,只加载指定路径下的一张图片,它可以加载由路径dataroot /path/to/data指定的一组单个图像。它也可以用于生成周期的结果仅为一侧与模型选项-模型测试。
colorization_dataset.py:它可以加载RGB格式的自然图像集,并将RGB格式转换为实验室颜色空间中的(L, ab)对。加载一张 RGB 图片并转化成(L,ab)对在 Lab 彩色空间,pix2pix用来绘制彩色模型。它是基于pix2pixel的着色模型(——模型着色)所需要的。
aligned_dataset.py:从同一个文件夹中加载的是一对图片 {A,B},测试过程中需要准备一个目录/path/到/data/test作为测试数据。
unaligned_dataset.py:从两个不同的文件夹下分别加载 {A},{B} ,在测试期间需要准备两个目录/path/to/data/testA和/path/to/data/testB。
文件夹datasets
存放我们所需要的数据集以及处理数据集的代码
文件夹docs
存放帮助文档.md
文件夹imgs
存放最终效果图片或动图
文件夹models
这里是这个项目的核心代码,各种类需要继承BaseModel写各种方法。
init_.py: 实现包和train、test脚本之间的接口。
base_model.py:继承了抽象类,也包括一些其他常用的函数setup,test,update_learning_rate,save_networks,load_networks,在子类中会被使用。
template_model.py: 实现自己模型的一个模板
pix2pix_model.py:实现了pix2pix 模型,用于在给定成对数据的情况下学习从输入图像到输出图像的映射。
colorization_model.py:继承了pix2pix_model,模型所做的是:将黑白图片映射为彩色图片。
cycle_gan_model.py:来实现cyclegan模型。用于学习图像到图像的转换也即图像翻译,无需成对数据。
networks.py:包含生成器和判别器的网络架构,normalization layers,初始化方法,优化器结构(learning rate policy)GAN的目标函数(vanilla,lsgan,wgangp)。 主要是实现一些普通的实现功能,写一些函数。
test_model.py:此测试模型可用于仅为一个方向生成CycleGAN结果。此模型将自动设置“–dataset_mode single”,它仅从一个集合加载图像。
文件夹options
包含训练模块,测试模块的设置TrainOptions和TestOptions都是 BaseOptions的子类。
base_options.py:定义基础的命令行参数,一般训练和测试都需要的且值不变的命令行参数大多定义在这个文件夹里,如数据集的路径等。
train_options.py:训练需要的options。
test_options.py:测试需要的options。
文件夹results
存放test.py的运行结果
文件夹scripts
存储运行脚本
文件夹util
主要包含一些有用的工具类,如数据的可视化。
get_data.py:用来下载数据集的脚本。
html.py:保存图片写成html。
image_pool.py:此类实现了一个存储之前生成的图像的图像缓冲区,该缓冲区使我们能够使用生成图像的历史更新判别器。
visualizer.py:保存图片,展示图片
utils.py:包含一些辅助函数:tensor2numpy转换,mkdir,诊断网络梯度等
train.py
特点:此脚本适用于各种模型
支持不同的模型(带有选项“-model”:例如,pix2pix、cyclegan、彩色化)。支持不同的数据集模式(带有选项“-dataset_mode”:例如,对齐、未对齐、单一、着色)。
需要指定数据集(’–dataroot’)、实验名称(’–name’)和模型(’–model’)。
它首先在给定选项的情况下创建模型、数据集和可视化工具。然后进行标准的网络培训。在培训期间,它还可以可视化/保存图像、打印/保存损失图和保存模型。
支持继续/暂停训练。使用“–continue_train”恢复先前的培训。
test.py
train.py运行会进行前向传播和反向求导,而test.py模型仅仅进行前向传播。用于图像到图像翻译的通用测试脚本。
使用train.py训练模型后,可以使用此脚本测试模型。它将从–checkpoints\u dir加载保存的模型,并将结果保存到–results\u dir。
它首先在给定选项的情况下创建模型和数据集。它将硬编码一些参数。
然后对–num_测试图像运行推断,并将结果保存到HTML文件中。