从零开始的CycleGAN学习笔记 代码分析

源代码地址:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

本文结合源代码中所给的./docs/overview.md 文件,对于原文中所给的整体代码内容进行简单分析,以达到进一步学习、了解CycleGAN这一网络结构的目的。
目前先进行总体的分析,等后期看情况对代码做逐行分析。

前言

(待补充)

1. train.py

1.1 概览

原话:

train.py is a general-purpose training script. It works for various
models (with option --model: e.g., pix2pix, cyclegan, colorization)
and different datasets (with option --dataset_mode: e.g., aligned,
unaligned, single, colorization).

简单来说就是一个综合的训练用文件,用于train各种模型,并且可以选择不同的dataset来用。
可以通过设置不同的option来达到不同的训练目的,这部分在option部分再详细说明,简单来说比较常用的有原文提到的 --model 来选择模型等。

1.2 调用

原代码中注释给出了一般情况下的调用方式。
举例:
CycleGAN model:
python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
pix2pix model:
python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA

--dataroot 为选择dataset的根目录 --name 为模型名称 --model为选择模型 

1.3 Code部分

大体流程是

  1. 读取train_options.py中的option 存为opt

  2. 利用data中的create_dataset类 根据opt创建对应的dataset

  3. 获取对应dataset长度 存为dataset_size

  4. 同样的方式使用create_model创建相应的model

  5. 上述各步运行完后均会输出相应的情况
    options - dataset - dataset_size - model(这一步里头没太看懂option是怎么最先输出出来的,待解决
    运行到这里的

    loading the model from ./checkpoints\horse2zebra\latest_net_G_A.pth
    

    之类的内容是由dataset和model创建的过程中调用到的函数输出的。

  6. 使用Visualizer(opt) 创建visualizer实现可视化,我用的是默认的visdom方法,登录默认网址http://localhost:8097 对于训练情况进行实时观测。
    训练期间并没有遇到别的文章评论区提到的不能进行可视化的问题。
    我自己的话是安装了visdom之后,用cmd运行python -m visdom.server 后在pycharm的terminal中跑代码,正常运行没遇到啥问题。

  7. 把总iter置零,没什么好说的

  8. 根据opt中所获取的epoch开始循环,总epoch在这里是200(opt.niter + opt.niter_decay + 1) 这里的话看了下option文档的描述,niter和niter_decay初始值都是100,代表的是学习率衰减到0的值(? 待解决

  • 每个epoch会reset visualizer,实现至少每个epoch更新一次
  • 每个epoch中进行内部循环,实现
    iter的更新
    data的输入 model.set_input(data)
    模型的更新 model.optimize_parameters() 计算loss、梯度下降、更新权重等
    (真正的训练部分 实际上这部分内容不同model有不同呈现)

1.4 循环中的部分参数

  • opt.display_freq :被total_iters整除时将结果展示于visdom并存到html文件中
  • opt.print_freq :被total_iters整除时将loss展示于visdom并输出到工作台
  • opt.save_latest_freq:被total_iters整除时保存latest模型
  • opt.save_epoch_freq:被epoch整除时保存latest和对应epoch名命名的模型(这里是5)
  • 每个epoch结束会输出相应情况以及总时间

2.test.py

2.1 概览

原话:

test.py is a general-purpose test script. Once you have trained your
model with train.py, you can use this script to test the model. It
will load a saved model from --checkpoints_dir and save the results to
–results_dir.

简单来说就是测试用的文件,用来test我们用train.py训练好的model,可以从
–checkpoints_dir来读取一个先前保存的model

2.2 调用

源代码注释中所给出的调用方式为:
CycleGAN model (both sides):
python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan

CycleGAN model (one side only):
python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout

调用中涉及的参数和train.py差不多 
特别提到的是--model test仅供one side 使用,该选项会自动设定--dataset_mode single

2.3 Code部分

大体流程如下:

  1. 读取test_options.py中的option 存为opt
  2. 重新设定了部分test中不能改变的参数(最后一部分再说)
  3. 读取dataset和model的两步的方式与train.py中完全一致
  4. 创建一个网页,用于存储test集的图片
  5. 利用dataset测试,通过model.test()获得相应结果,将图片存到html网页中
  6. 保存网页

源代码在这里提到

test with eval mode. This only affects layers like batchnorm and dropout.
For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode.
For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout.

若使用检验模式(eval mode)进行test仅影响到bn和dropout(这块还想再测试下,待解决),这两种方法只在pix2pix方法中涉及到,并不会影响CycleGAN的结果。

2.4 参数

opt.num_threads = 0
仅支持单线程(没理解错的话)
opt.batch_size = 1
仅支持batch_size为1
opt.serial_batches = True
停止随机取样,如果需要随机选择图片产生结果则把这行注释掉
opt.no_flip = True
停止翻转(flip)
opt.display_id = -1
不使用visdom呈现结果,test结果保存为一个html文件

3.Data部分

3.1 概览

首先看一下原文中大概是这么说的:

data directory contains all the modules related to data loading and preprocessing. To add a custom dataset class called dummy, you need to add a file called dummy_dataset.py and define a subclass DummyDataset inherited from BaseDataset. You need to implement four functions: init (initialize the class, you need to first call BaseDataset.init(self, opt)), len (return the size of dataset), getitem (get a data point), and optionally modify_commandline_options (add dataset-specific options and set default options). Now you can use the dataset class by specifying flag --dataset_mode dummy.

data dictionary 包括和数据读取还有预处理相关的所有模块。如果要添加一个叫dummy的数据集,我们需要添加一个叫做dummy_dataset.py的文件,并定义一个继承于BaseDataset的子类DummyDataset,我们需要4个函数(原文中所提到的)来建立类(这个在后头说)。

3.2 Code部分

3.2.1 init.py

定义的函数
  1. find_dataset_using_name(dataset_name):
    利用dataset_name引入文件名为"“data/[dataset_name]_dataset.py”"对应的模块,从中找到Basedataset的子类dataset并返回。若未能找到(dataset is None) 则会报NotImplementedError。

  2. get_option_setter(dataset_name):
    如同字面义,获得设定option的setter,实际上是调用了上面的find_dataset_using_name(dataset_name)并返回dataset_class.modify_commandline_options。(这个在base_dataset里头说)

  3. create_dataset(opt):
    根据opt创建dataset,直接看这段注释里的原文就完事了。

Create a dataset given the option.
This function wraps the class CustomDatasetDataLoader.
This is the main interface between this package and ‘train.py’/‘test.py’
Example:
>>> from data import create_dataset
>>> dataset = create_dataset(opt)

操作上就是调用了下 CustomDatasetDataLoader类中的load_data()功能,就是个接口没什么好说的。

定义的类

CustomDatasetDataLoader()
data的包装类,包括四个function

  1. __init __(self, opt):

Initialize this class
Step 1: create a dataset instance given the name [dataset_mode]
Step 2: create a multi-threaded data loader.

初始化,建立一个类的实例和一个多线程数据loader。
操作上来说,利用find_dataset_using_name(opt.dataset_mode)和self.dataset = dataset_class(opt)创建类的实例,并输出数据集创建的信息。
读取器loader用的是pytorch的接口torch.utils.data.DataLoader(),根据预设参数读取相应的data。

  1. load_data(self):
    return self,没啥好说的。

  2. __ len__(self):
    返回数据集的长度。

  3. __ iter__(self):
    返回一个batch的data

3.2.2 base_dataset.py

主要是建立抽象基类BaseDataset,这部分后面再补充

3.2.3 image_folder.py

待补充

3.2.4 template_dataset.py

一个模板类

3.3 其他

待补充

4.Model部分

4.1 概览

先看下原文是这么说的:

models directory contains modules related to objective functions, optimizations, and network architectures. To add a custom model class called dummy, you need to add a file called dummy_model.py and define a subclass DummyModel inherited from BaseModel. You need to implement four functions: init (initialize the class; you need to first call BaseModel.init(self, opt)), set_input (unpack data from dataset and apply preprocessing), forward (generate intermediate results), optimize_parameters (calculate loss, gradients, and update network weights), and optionally modify_commandline_options (add model-specific options and set default options). Now you can use the model class by specifying flag --model dummy.

简单来说,这一部分包括了与目标函数、优化还有网络结构相关的模块。如果我们要创建一个叫dummy的模型,跟dataset部分的处理相似,我们需要添加一个名为dummy_model.py的文件,并定义一个继承于BaseModel的子类 DummyModel,同样要包含四个函数。

  1. __ init__:初始化类,利用BaseModel.__ init__(self, opt)
  2. set_input:从数据集中提取data并预处理。
  3. forward:产生intermediate的结果(中间结果)
  4. optimize_parameters:计算loss、梯度并更新网络权重
  5. (可选) modify_commandline_options:添加一些专属于模型的option并设定默认值(datasets那里也有这个)
    由于这部分比较关心的是cycleGAN的情况,所以着重看cycle_gan_model.py 部分的内容,其他部分一笔带过,留待后面补充。

4.2 cycle_gan_model.py

4.2.1 原文说了些什么

题外话:这部分用到了ImagePool(在util中定义,用于存储先前generate出来的image,在相关文件中有)
源文件中对于这部分的解释:

cycle_gan_model.py implements the CycleGAN model, for learning image-to-image translation without paired data. The model training requires --dataset_mode unaligned dataset. By default, it uses a --netG resnet_9blocks ResNet generator, a --netD basic discriminator (PatchGAN introduced by pix2pix), and a least-square GANs objective (–gan_mode lsgan).

该文件包括了CycleGAN模型,用于学习无配对数据情形下的image-to-image的translation过程。
该模型需要设置 --dataset_mode unaligned dataset (显然)。默认条件下,使用的G为resnet_9blocks ResNet generator 残差网络结构(这块相关知识想再看看,待解决),使用的D是PatchGAN introduced by pix2pix,GANmode是LSGAN(参考此网址)(简单来说是最小二乘loss而非交叉熵loss,这块相关知识也想再看看,待解决)。

4.2.2 Code部分

这部分直接一行一行分析下去有点僵硬,先大体说下情况。
除了刚才提到的ImagePool以外,具体的网络结构中涉及到的Generator和Discriminator都是用network.py中相应的函数生成的,那个文件中包含了更多的在网络设计上的细节,而这个文件中则更多的是拓扑意义上的网络结构定义,看完这部分的code,很多在跑代码伊始时便产生的困惑迎刃而解了,特别是那个idt_A和idt_B,跑了几十个epoch的时候都根本没看出来是用来干嘛的,百度了半天也没看懂,但是看完这部分代码之后就清楚了许多。
这个文件其实就是定义了一个继承于BaseModel(相应抽象基类)的子类CycleGANModel。

定义的函数
1.modify_commandline_options(parser, is_train=True):

本身是一个optional的function,在这个model中用到,用于添加了几个额外的参数:

  1. lambda_A:用于给Forward cycle loss加权(默认为10.0)
    Forward cycle loss = lambda_A * ||G_B(G_A(A)) - A||
    A → B → A的cycle loss
  2. lambda_B:用于给Backward cycle loss加权(默认为10.0)
    Backward cycle loss = lambda_B * ||G_A(G_B(B)) - B||
    B → A → B的cycle loss
  3. lambda_identity(optional,这个一开始没太弄懂):给Identity loss 加权。(默认为0.5)
    Identity loss = lambda_identity * (||G_A(B) - B|| * lambda_B + ||G_B(A) - A|| * lambda_A)
    这里把B(比如斑马)扔到G_A(比如把马变成斑马)的生成器中并算loss的目的是在于防止画风的过度迁移(我的理解的话是防止一些没必要的要素,比如非马的其他部分迁移过去),通过这个loss,可以看G_A(B)和B的区别,如果没有什么区别的话说明G_A对于斑马是没有太大影响的,只是把A(马)变成了B domain 中的斑马这样。
2.__ init__(self, opt):

根据option初始化CycleGAN。

  1. BaseModel.init(self, opt)
    不用多说啥了,必须的调用。
  2. 明确要print的loss(‘D_A’, ‘G_A’, ‘cycle_A’, ‘idt_A’, ‘D_B’, ‘G_B’, ‘cycle_B’, ‘idt_B’)
    train和test会调用BaseModel.get_current_losses
  3. 明确要print的image(‘real_A’, ‘fake_B’, ‘rec_A’,‘real_B’, ‘fake_A’, ‘rec_B’)这里由于default的lambda_identity == 0.5,所以还有’idt_B’和’idt_A’这两张图。
  4. 明确要保存的model
    若是train则为G_A G_B D_A D_B
    若是test则只有G_A G_B
  5. 定义network
    调用networks.define_G 和 networks.define_D,根据需要的条件创建,**改网络的话在这要做调整。**如果 isTrain == True 则要定义D,否则只需要定义G
  6. 对于train的情况,做出了一些相应的细节处理
    比如:
    在opt.lambda_identity > 0.0的情况下断言输入输出channel数相同;
    创建两个pool:fake_A_pool和fake_B_pool来存储过往数据
    定义了GANloss(用的是在network.py里头定义的loss)还有Cycle以及Idtloss(用的是L1loss,直接调的torch.nn.L1loss)
    初始化optimizerG和D,用的都是AdamOptimizer,调了下相关参数。
3.set_input(self, input):

根据设定的方向决定input(A→B还是B→A),也可以用来交换域。

4.forward(self):

fake_B = G_A(A) A域图片经过G_A生成的伪B,目的是骗过D_B
rec_A = G_B(G_A(A)) 其实就是把fake_B扔进G_B生成的A,一般来说要求和原图尽可能接近
另外的fake_A和rec_B同理,就不说了。
特别地,(其实这两个不在这里)
idt_A 和 idt_B 分别对应于G_A(B) 和 G_B(A)
意思是idt_A是把B(比如斑马)扔到G_A(比如把马变成斑马)的生成器中产生的结果,这里idt_A应与原图(B)越接近越好,表明G_A对B域图片几乎没有影响。

5.backward_D_basic(self, netD, real, fake):

计算D的loss
loss_D = (loss_D_real + loss_D_fake) * 0.5
而后调用loss_D.backward()计算gradient

其他几个函数都是相仿的,比较没意思就不写了(摸了 )。

4.2.3 loss部分

loss_G_A:GAN loss D_A(G_A(A))
loss_G_B:GAN loss D_B(G_B(B))
loss_cycle_A:Forward cycle loss || G_B(G_A(A)) - A|| * lambda_A 这里用的是L1loss
loss_cycle_B:Backward cycle loss || G_A(G_B(B)) - B|| * lambda_B
loss_D_A:self.backward_D_basic(self.netD_A, self.real_B, fake_B) 求平均的结果
loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A) 同理
loss_idt_A = ||G_A(B) - B|| * lambda_B * lambda_idt 这里用的是L1loss
loss_idt_B = ||G_B(A) - A|| * lambda_A * lambda_idt

  • 22
    点赞
  • 125
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值