基于飞桨paddlepaddle复现StarGAN v2
StarGAN v2: Diverse Image Synthesis for Multiple Domains
通过对论文的精读,完全掌握StarGAN v2
首先还是看下StarGAN v2在CelebA-HQ和AFHQ两个数据集上的表现(没打label的是实际的样张)
名词解释
- domian:是指一组图像,具有一个独特类别的图像组。(男、女;猫、狗、老虎、豹)
- style:是在一个domain中,每个图片都有自己独特的style.(对人来说可以是发型、发色、表情、是否有胡子、瞳孔颜色等等;对动物来讲可以是毛色、品种、也包含表情等等)
引言
-
一个理想的图像到图像的翻译方法应该能够综合考虑各个领域的不同风格的图像。也就是说一个好的gan就要能将我们输入的图片按照domain和style生成新的图片.StarGAN仍然会为每个domain学习一个确定性映射,它不能捕获数据分布的多模态特性。这个限制来自于这样一个事实,即每个domain都由一个预定的label表示。注意,Generator接收一个固定的label(例如一个热向量)作为输入,因此它不可避免地在给定源图像的每个domain中生成相同的输出。
-
解决风格多样性问题,通常是将标准高斯分布随机抽样出来的低维latent code(隐编码)送人生成器得到的.在生成图片时把通过latent code方式得到的Style加入到图片生成器中,来得到相应的图片。这样的方法就只考虑两个domain之间映射,不能扩展到不断增加的域数量。如果有K个domain我们就需要训练K(K-1)个生成器去处理每个域之间的转换,这样就限制了他们的应用.
-
为了解决可伸缩性的问题,早版本的StarGAN是通过使用一个生成器来映射所有的Domain(stargan的名字就是因为这样模型结构而得名的).Generator将domain label作为附加输入,并学习将image转换为相应的Domain。
-
想要两全其美,就有了现在的Starganv2,他是一种可伸缩的方法,可以在多个Domain中生成不同的图像。特别是,他从StarGAN开始,提出的领域特定样式代码替换其domain label,该代码可以表示特定Domain的不同样式。为此,先介绍了两个模块,一个Mapping Network(映射网络)和一个Style Encoder(风格编码器)。mapping network学习将随机高斯噪声转换为Style code,而style encoder从给定的参考图像中提取style code。考虑多重因素Domain,两个模块都有多个输出分支,每个分支都提供特定Domain的style code。最后,利用这些样式代码,Starganv2的Generator学会了成功地在多个Domain上合成不同的图像.
-
Starganv2论文还开源了一个新的动物脸数据集(AFHQ),Animal Faces-HQ.
Stargan v2基本框架
设X和Y分别是图像和可能Domain的集合。在给定图像x∈x和任意Domainy∈y的情况下,我们的目标是训练一个能够生成与图像x对应的每个Domain y的不同图像的生成器G,在每个Domain的学习风格空间中生成特定于Domain的Style code,训练G来反映Style code。上图图展示了StarGAN v2的概述,它由下面描述的四个模块组成。
Generator G
-
Generator:生成器G将输入图像x转换为输出图像G(x,s),该图像反映由(Mapping Network)F或(Style Encoder)E提供的Domain特定样式代码s(style encoder)。使用自适应实例规范化(AdaIN)将s注入G中。可以观察到s被设计成表示特定Domain y的样式,这就消除了提供y到G的必要性,并允许G合成所有Domain的图像。
Generator包含4个上采样模块,4个中间层模块和4个下采样模块,所有模块都继承了预激活的参差单元。还在上、下采样模块中用了instance normalization (IN)和adaptive instance normalization (AdaIN).把Style code 通过AdaIN注入所有的AdaIN中,并通过学习仿射变换提供缩放和移位向量。
对CelebA-HQ还增加了上、下采样模块各一层,移除了上采样模块的residual blocks,增加了wing,将生成人脸关键部位的mask,使得人脸mask区域在通过Generator后仍能得以保留.
AdaIN
- 本文通过style code获得style_mean(ys)和style_std(yb),通过style_mean和style_std改变feature map的分布来实现风格迁移。
Mapping Network F
Mapping Network:给定一个latent code z和一个Domain y,(Mapping Network)F生成一个样式代码s=F y(z),其中fy(·)表示对应于Domain y的F的输出。F由具有多个输出分支的MLP组成,为所有可用Domain提供Style code。F通过对latent code z∈z和Domain y∈y进行采样,可以产生不同Style code。多任务体系结构允许F高效有效地学习所有Domain的Style code。
Mapping Network由一个具有K个输出分支的MLP组成,其中K表示Domain的数目。由4个共享Domain权重的全连接网络,再加上4个每个domain各自独立的全连接网络(相当于domain个网络)。latent code尺寸是16,hidden layer尺寸是512,style code尺寸是64.latent code是从标准高斯分布采样得到的。
作者尝试将像素标准化加入latent code没有提高实际效果,作者还试过在此部分加feature normalizations,不过表现也不怎么好。
Style Encoder E
Style Encoder:给定一个图像x及其对应的Domain y,Style Encoder E提取x的样式代码s=E y(x)。这里,E y(·)表示E对应于Domain y的输出。与Mapping Network F类似,Style Encoder E受益于多任务学习设置。E可以使用不同的参考图像生成不同Style code。这就可以使得G合成时参考了图像x的Style code输出图像。
Style Encoder由一个具有K个输出分支的CNN组成,其中K表示Domain的数目。由6个共享Domain权重的ResBlk,和一个每个domain独立的全连接网络(相当于domain个网络)。没有使用global average pooling 提取参考图像的精细特征。D=64 是输出的Style Encoder尺寸.
Discriminator D
Discriminator:鉴别器D是一个多任务鉴别器,它由多个输出分支组成。每个分支D y学习二值分类,确定图像x是其Domain y的真实图像还是由G生成的伪图像G(x,s)。
Discriminator是由6个ResBlk,激活函数采用leaky ReLU,用K(domain个数)个全连接输出每个domain real or fake的分类。D=1,用来判断图片是真是假。没有使用feature normalization或者PatchGAN。
Style Encoder和Discriminator的大体结构基本是一样的。
Training objectives
给定图片 x ∈ X 他的 domain y ∈ Y,通过以下损失函数来训练我们的网络
1.Adversarial objective.
- GAN的一般损失
- 还使用了R1 正则 ,即该文的zero-centered gradient penalty(又称为 the - log D trick),其公式为,即鉴别器输出对真实图像的导数的模的平方:
2.Style reconstruction
- 目的为了使Generator在生成图片时能够运用style code,训练一个Style Encoder E来鼓励多个Domain的不同输出。学习的Style Encoder E允许Generator转换输入图像,反映参考图像的Style code。
3.Style diversification
- 源自MSGAN(去掉了分母项,训练更加稳定),目的使生成器G能够产生不同的图像,最大化正则化项迫使Generator去探索图像空间并发现有意义的style code以生成不同的图像。
4.Preserving source characteristics.
Cycle consistency loss
- 源自CycleGAN 的损失,目的为了让生成的图片适当的保留原始图片Domain的特征
5.Full objective
其中的超参数作者给出了实际训练值
原作训练的细节:采用的batch_size=8;iterations=100k;λsty=1,λds=1,λcyc=1 for CelebA-HQ;λsty=1,λds=2,λcyc=1 for AFHQ;λds根据iterations线性衰减到0;采用非饱和对抗性损失并加入参数λ=1的R1正则;Adam优化器β1=0,β2=0.99;lr=1e-4 for G D E,f_lr=1e-6 for F;除了模型D,在推理是使用滑动平均(EMA),参数采用He initialization初始化,biases=0;在AdaIN中biases=1.
最后在Tesla V100 上训练了三天;
训练过程
- 训练参数说明
- 输入数据:(黄色背景)
- x_real:输入的真实图片
- y_org:输入图片的Domain
- y_trg:参考图片的Domain
- x_ref:参考图片
- x_ref2:参考图片
- z_trg:随机噪声
- z_trg2:随机噪声
- 输出数据(中间变量):(绿色背景)
- s_try(2):要输入Generator的style code,由Mapping network根据z_trg和y_trg生成,或者Style encoder根据x_ref和y_trg生成
- x_fake(2):由Generator根据输入图片x_real和s_try生成的假图片
- s_org:由真实图片和其Domain通过Style encoder生成的style code
- s_pred:由x_fake和s_try通过Style encoder生成的style code
- x_rec:由s_pred和x_fake通过Generator生成的假图片
计算d_loss后,更新D的参数
计算g_loss后,更新E、M、G的参数
小结
StarGAN v2:
- 采用Style Encoder,Mapping Network来得到图片的style code。其中Style Encoder来获取参考图片的style code,用来生成我们想要的图片;Mapping Network来随机生成style code,保证style分布多样性。
- 多domain输出,除了Generator网络,其他三个网络都是多domain输出,让一个模型就可以学得多个domain的信息。
- 自适应实例归一化(AdaIN)层,AdaIN层与BN、IN类似,都是在网络内部改变feature map的分布,实现风格迁移。
- 借鉴CycleGAN了思想,使生成图片与原图片相近。
- 增加了style diversity loss与 R1 正则,保证训练的收敛。
- 利用EMA,能使得模型更加的鲁棒。
- 一个预训练好的人脸关键点模型FAN,产生关键部位的mask,使得原图像mask区域在转换后仍能得以保留。
- to be continue…
复现代码 to be continue…
写的不好,而且现在还有很多问题,先不公开了,有效果时我会整理好发布的。
D loss定义
def d_train_step(self, x_real, y_org, y_trg, z_trg=None, x_ref=None):
#no_grad()下的内容不计算梯度,只训练D,可以减少计算以显存占用
with fluid.dygraph.no_grad():
if z_trg is not None:
s_trg = self.mapping_network([z_trg, y_trg])
else: # x_ref is not None
s_trg = self.style_encoder([x_ref, y_trg])
x_fake = self.generator([x_real, s_trg])
real_logit = self.discriminator([x_real, y_org])
fake_logit = self.discriminator([x_fake, y_trg])
real_loss , fake_loss = self.adv_weight * discriminator_loss(self.gan_type, real_logit, fake_logit)
d_adv_loss= real_loss + fake_loss
#r1_gp_req 这个函数还没有实现,paddle.fluid.dygraph.grad 这个API还有问题计算不了discriminator的梯度
if self.gan_type == 'gan-gp':
d_adv_loss += self.r1_weight * r1_gp_req(real_logit, x_real)
d_loss = d_adv_loss
return real_loss , fake_loss, d_adv_loss, d_loss
G loss定义
def g_train_step(self, x_real, y_org, y_trg, z_trgs=None, x_refs=None):
if z_trgs is not None:
z_trg, z_trg2 = z_trgs
if x_refs is not None:
x_ref, x_ref2 = x_refs
# adversarial loss
if z_trgs is not None:
s_trg = self.mapping_network([z_trg, y_trg])
else:
s_trg = self.style_encoder([x_ref, y_trg])
x_fake = self.generator([x_real, s_trg])
fake_logit = self.discriminator([x_fake, y_trg])
g_adv_loss = self.adv_weight * generator_loss(self.gan_type, fake_logit)
# style reconstruction loss
s_pred = self.style_encoder([x_fake, y_trg])
g_sty_loss = self.sty_weight * L1_loss(s_pred, s_trg)
# diversity sensitive loss
if z_trgs is not None:
s_trg2 = self.mapping_network([z_trg2, y_trg])
else:
s_trg2 = self.style_encoder([x_ref2, y_trg])
x_fake2 = self.generator([x_real, s_trg2])
# x_fake2 = tf.stop_gradient(x_fake2.stop_gradient = True
x_fake2.stop_gradient = True
#x_fake2.numpy()
g_ds_loss = -self.ds_weight * L1_loss(x_fake, x_fake2)
# cycle-consistency loss
s_org = self.style_encoder([x_real, y_org])
x_rec = self.generator([x_fake, s_org])
g_cyc_loss = self.cyc_weight * L1_loss(x_rec, x_real)
g_loss = g_adv_loss + g_sty_loss + g_ds_loss + g_cyc_loss
return g_adv_loss, g_sty_loss, g_ds_loss, g_cyc_loss, g_loss
r1_reg 函数
这里计算梯度的API有些问题,还没搞出来
EMA
这里感谢老董的分享
https://github.com/dbsxdbsx
def soft_update(target, source, decay):
"""
Copies the parameters from source network (x) to target network (y)
using the below update
y = decay * source + (1 - decay) * target_param
:param target: Target network (PaddleDynaGraphModel)
:param source: Source network (PaddleDynaGraphModel)
:decay: decay ratio should be super lower than 1, in range of [0,1]
:return:
https://paddlepaddle.org.cn/documentation/docs/zh/api_cn/fluid_cn/Variable_cn.html#set-value
"""
target_model_map = dict(target.named_parameters())
for param_name, source_param in source.named_parameters():
target_param = target_model_map[param_name]
target_param.set_value(decay * source_param +
(1.0 - decay) * target_param)
!unzip data/data42681/afhq.zip -d ./stargan-v2-paddle/dataset/
!pip install tqdm==4.46.1
%cd stargan-v2-paddle/
!python main.py --dataset afhq --phase train
Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/
Collecting tqdm==4.46.1
[?25l Downloading https://mirrors.tuna.tsinghua.edu.cn/pypi/web/packages/f3/76/4697ce203a3d42b2ead61127b35e5fcc26bba9a35c03b32a2bd342a4c869/tqdm-4.46.1-py2.py3-none-any.whl (63kB)
[K |████████████████████████████████| 71kB 20.9MB/s eta 0:00:01
[?25hInstalling collected packages: tqdm
Found existing installation: tqdm 4.36.1
Uninstalling tqdm-4.36.1:
Successfully uninstalled tqdm-4.36.1
Successfully installed tqdm-4.46.1
/home/aistudio/stargan-v2-paddle
##### Information #####
# gan type : gan
# dataset : afhq
# domain_list : ['cat', 'dog', 'wild']
# batch_size : 12
# max iteration : 100000
# ds iteration : 100000
##### Generator #####
# latent_dim : 16
# style_dim : 64
# num_style : 5
##### Mapping Network #####
# hidden_dim : 512
##### Discriminator #####
# spectral normalization : False
len(records) 14630
['./dataset/afhq/train/cat/pixabay_cat_000065.jpg', './dataset/afhq/train/cat/pixabay_cat_003360.jpg', [0]]
Dataset number : 14630
W0803 02:25:53.415187 140 device_context.cc:252] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 10.1, Runtime API Version: 9.0
W0803 02:25:53.419502 140 device_context.cc:260] device: 0, cuDNN Version: 7.3.
!pip install tqdm==4.46.1
%cd stargan-v2-paddle/
!python main.py --dataset afhq --phase test
Looking in indexes: https://pypi.mirrors.ustc.edu.cn/simple/
Requirement already satisfied: tqdm==4.46.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (4.46.1)
[Errno 2] No such file or directory: 'stargan-v2-paddle/'
/home/aistudio/stargan-v2-paddle
##### Information #####
# gan type : gan
# dataset : afhq
# domain_list : ['cat', 'dog', 'wild']
# batch_size : 5
# max iteration : 100000
# ds iteration : 100000
##### Generator #####
# latent_dim : 16
# style_dim : 64
# num_style : 5
##### Mapping Network #####
# hidden_dim : 512
##### Discriminator #####
# spectral normalization : False
W0803 02:18:52.507735 2067 device_context.cc:252] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 9.2, Runtime API Version: 9.0
W0803 02:18:52.512806 2067 device_context.cc:260] device: 0, cuDNN Version: 7.3.
<<load model success>>
reference-guided synthesis
100%|█████████████████████████████████████████████| 6/6 [00:00<00:00, 93.16it/s]
9it [00:00, 107.21it/s]
latent-guided synthesis
100%|█████████████████████████████████████████████| 6/6 [00:03<00:00, 1.98it/s]
[*] Test finished!