CycleGAN原理
1. 初识CycleGAN
1.1 什么是CycleGAN
-
CycleGAN是一种完成 图 像 到 图 像 \color{red}图像到图像 图像到图像的转换的一种GAN。
-
图像到图像的转换是一类视觉和图形问题,其 目 标 是 获 得 输 入 图 像 和 输 出 图 像 之 间 的 映 射 \color{red}目标是获得输入图像和输出图像之间的映射 目标是获得输入图像和输出图像之间的映射。
-
但是,对于许多任务,配对的训练数据是获取不了的(比如:相同动作下的马和斑马)。
-
CycleGAN提出了一种 在 没 有 成 对 数 据 的 情 况 下 , 学 习 将 图 像 从 源 域 X 转 换 到 目 标 域 Y 的 方 法 \color{red}在没有成对数据的情况下,学习将图像从源域X转换到目标域Y的方法 在没有成对数据的情况下,学习将图像从源域X转换到目标域Y的方法。
1.2 数学表示
- 两个图像域:
A
\color{blue}A
A、
B
\color{blue}B
B;
- A A A:假设是马的数据集;
- B B B:假设是斑马的数据集。
- 两个单射(生成器):
G
A
2
B
:
A
→
B
\color{blue}G_{A2B} : A \rightarrow B
GA2B:A→B,
G
B
2
A
:
B
→
A
\color{blue}G_{B2A}: B \rightarrow A
GB2A:B→A ;
- 生成器 G A 2 B : A → B \color{blue}G_{A2B}: A \rightarrow B GA2B:A→B:将图像从 A A A 转换为 B B B(例如马到斑马)
- 生成器 G B 2 A : B → A \color{blue}G_{B2A}: B \rightarrow A GB2A:B→A:将图像从 B B B 转换为 A A A(例如斑马到马)
这两个映射 是 双 射 \color{red}是双射 是双射。这是通过循环一致性损失来实现的:
- G B 2 A ( G A 2 B ( A ) ) ≈ A \color{blue}G_{B2A}\left(G_{A2B} \left(A\right)\right) \approx A GB2A(GA2B(A))≈A
- G A 2 B ( G B 2 A ( B ) ) ≈ B \color{blue}G_{A2B} \left(G_{B2A}\left(B\right)\right) \approx B GA2B(GB2A(B))≈B
- 两个判别器:
D
B
\color{blue}D_{B}
DB、
D
A
\color{blue}D_{A}
DA
- 判别器 D A \color{blue}D_A DA:评分 A A A 的图像看起来有多真实(例如,这个图像看起来像一匹马吗?)
- 判别器 D B \color{blue}D_B DB:对 B B B 图像的真实程度打分(例如,这张图像看起来像斑马吗?)
循 环 一 致 性 \color{red}循环一致性 循环一致性是:
- 如果你能够训练这对 GAN 从 A → B → A \color{blue}A \rightarrow B \rightarrow A A→B→A 转换,即在确保循环一致性的同时生成图像;
- 那么 A → G ( A ) → F ( G ( A ) ) ≈ A \color{blue}A \rightarrow G(A) \rightarrow F(G(A)) \approx A A→G(A)→F(G(A))≈A,那么你就可以很好地学习图像Translate任务了。
2. 模型介绍
2.1 基本过程
以
G
A
2
B
\color{blue}G_{A2B}
GA2B和
D
B
\color{blue}D_B
DB为例。
2.2 生成器
生成器由三部分组成:
编
码
器
\color{red}编码器
编码器、
转
换
器
\color{red}转换器
转换器、
解
码
器
\color{red}解码器
解码器。
2.2.1 编码器
第一步是利用卷积网络从输入图像中提取特征。整个编码过程,将
D
A
\color{blue}D_A
DA 域中一个尺寸为 [256,256,3]
的图像,输入到设计的编码器中,获得了尺寸为 [64,64,256]
的输出
O
A
e
n
c
\color{blue}O_{Aenc}
OAenc。
2.2.2 转换器
-
这些网络层的作用是组合图像的不同相近特征,然后基于这些特征,确定如何将图像的特征向量 O A e n c \color{blue}O_{Aenc} OAenc 从 D A \color{blue}D_A DA域转换为 D B \color{blue}D_B DB域的特征向量。
-
作者使用 6 层 R e s n e t \color{red}Resnet Resnet模块。
- 一个 Resnet 模块是一个由两个卷积层组成的神经网络层,其中部分输入数据直接添加到输出。
这样做是为了确保先前网络层的输入数据信息直接作用于后面的网络层,使得相应输出与原始输入的偏差缩小,否则原始图像的特征将不会保留在输出中且输出结果会偏离目标轮廓。
- 这个任务的一个主要目标是保留原始图像的特征,如目标的大小和形状,因此残差网络非常适合完成这些转换。Resnet 模块的结构如下所示:
- 一个 Resnet 模块是一个由两个卷积层组成的神经网络层,其中部分输入数据直接添加到输出。
-
O B e n c \color{blue}O_{Benc} OBenc表示该层的最终输出,尺寸为
[64,64,256]
,这可以看作是 D B \color{blue}D_B DB域中图像的特征向量。
2.2.3 解码器
- 解码过程与编码方式完全相反,从特征向量中还原出低级特征,这是利用了 反 卷 积 层 ( d e c o n v o l u t i o n ) \color{red}反卷积层(deconvolution) 反卷积层(deconvolution)来完成的。
- 将这些低级特征转换得到一张在
D
B
\color{blue}D_B
DB域中的图像,得到一个大小为
[256,256,3]
的生成图像 G e n B \color{blue}Gen_B GenB。
2.3 判别器
判别器将一张图像作为输入,并尝试预测其为原始图像或是生成器的输出图像。判别器的结构如下所示:
判别器本身就属于卷积网络,需要从图像中提取特征;然后是确定这些特征是否属于该特定类别,使用一个产生一维输出的卷积层来完成这个任务。
2.4 Loss计算
Loss分为6个部分,可分为3类:
2.4.1 原始GAN的损失:
- 对于判别器
D
B
(
线
1
)
\color{blue}D_B(线1)
DB(线1):
L G A N ( G A 2 B , D B , A , B ) = E b ∈ P B log D B ( b ) + E a ∈ P A log [ 1 − D B ( G A 2 B ( a ) ) ] (2.4.1) \color{red}\mathcal{L}_{GAN}\left(G_{A2B}, D_{B}, A, B\right) = \mathbb{E}_{b \in \mathbb{P}_B} \log D_B(b) + \mathbb{E}_{a \in \mathbb{P}_{A}} \log[1-D_B(G_{A2B}(a))]\tag{2.4.1} LGAN(GA2B,DB,A,B)=Eb∈PBlogDB(b)+Ea∈PAlog[1−DB(GA2B(a))](2.4.1)- A到B的判别网络loss(
D
B
D_{B}
DB主要是判别
l
o
s
s
f
a
k
e
loss_{fake}
lossfake),生成网络loss(
G
A
2
B
G_{A2B}
GA2Bloss):
- i m g A → G A 2 B ( i m g A ) → f a k e B → D B ( f a k e B ) → v a l i d B img_A \rightarrow G_{A2B}(img_A) \rightarrow fake_B \rightarrow D_B(fake_B) \rightarrow valid_B imgA→GA2B(imgA)→fakeB→DB(fakeB)→validB
- 输入 i m g A img_A imgA,输出 v a l i d B valid_B validB,判别网络 D B D_B DB目标 f a k e fake fake,生成网络 G A 2 B G_{A2B} GA2B目标 v a l i d valid valid。
- 真实数据B的鉴别网络loss(
D
B
D_{B}
DB主要是判别
l
o
s
s
r
e
a
l
loss_{real}
lossreal):
- i m g B → D B ( i m g B ) → v a l i d B img_B \rightarrow D_B(img_B) \rightarrow valid_B imgB→DB(imgB)→validB
- 输入 i m g B img_B imgB,输出 v a l i d B valid_B validB,鉴别网络 D B D_B DB目标 v a l i d valid valid。
- A到B的判别网络loss(
D
B
D_{B}
DB主要是判别
l
o
s
s
f
a
k
e
loss_{fake}
lossfake),生成网络loss(
G
A
2
B
G_{A2B}
GA2Bloss):
- 对于判别器
D
A
(
线
2
)
\color{blue}D_A(线2)
DA(线2):
L G A N ( G B 2 A , D A , B , A ) = E a ∈ P A log D A ( a ) + E b ∈ P B log [ 1 − D A ( G B 2 A ( b ) ) ] (2.4.2) \color{red}\mathcal{L}_{GAN}\left(G_{B2A}, D_{A}, B, A\right) = \mathbb{E}_{a \in \mathbb{P}_A} \log D_A(a) + \mathbb{E}_{b \in \mathbb{P}_{B}} \log[1-D_A(G_{B2A}(b))]\tag{2.4.2} LGAN(GB2A,DA,B,A)=Ea∈PAlogDA(a)+Eb∈PBlog[1−DA(GB2A(b))](2.4.2)解释同上。
2.4.2 Cycle一致性损失:
- 对于生成器
G
B
2
A
(
线
6
)
\color{blue}G_{B2A}(线6)
GB2A(线6)和生成器
G
A
2
B
(
线
5
)
\color{blue}G_{A2B}(线5)
GA2B(线5):
L c y c ( G A 2 B , G B 2 A ) = E a ∼ p d a t a ( A ) [ ∣ ∣ G B 2 A ( G A 2 B ( a ) ) − a ∣ ∣ 1 ] + E b ∼ p d a t a ( B ) [ ∣ ∣ G A 2 B ( G B 2 A ( b ) ) − b ∣ ∣ 1 ] (2.4.3) \color{red}\begin{array}{ll}\mathcal{L}_{cyc}\left(G_{A2B}, G_{B2A}\right) =& \mathbb{E}_{a \sim p_{data}\left(A\right)}\left[||G_{B2A}\left(G_{A2B}\left(a\right)\right) - a||_{1}\right] \\ &+ \mathbb{E}_{b \sim p_{data}\left(B\right)}\left[||G_{A2B}\left(G_{B2A}\left(b\right)\right) - b||_{1}\right]\end{array}\tag{2.4.3} Lcyc(GA2B,GB2A)=Ea∼pdata(A)[∣∣GB2A(GA2B(a))−a∣∣1]+Eb∼pdata(B)[∣∣GA2B(GB2A(b))−b∣∣1](2.4.3)- 生成网络loss(
G
B
2
A
(
G
A
2
B
)
G_{B2A}(G_{A2B})
GB2A(GA2B) loss):
- i m g A → G A 2 B ( i m g A ) → f a k e B → G B 2 A ( f a k e B ) → r e c A img_A \rightarrow G_{A2B}(img_A) \rightarrow fake_B \rightarrow G_{B2A}(fake_B) \rightarrow rec_A imgA→GA2B(imgA)→fakeB→GB2A(fakeB)→recA;
- 输入 i m g A img_A imgA,输出 r e c A rec_A recA,生成网络 G A 2 B → G B 2 A G_{A2B} \rightarrow G_{B2A} GA2B→GB2A目标 i m g s A imgs_A imgsA;
- 生成网络loss( G A 2 B ( G B 2 A ) G_{A2B}(G_{B2A}) GA2B(GB2A) loss)解释同上
- 生成网络loss(
G
B
2
A
(
G
A
2
B
)
G_{B2A}(G_{A2B})
GB2A(GA2B) loss):
2.4.3 Identity映射损失:
- 对于生成器
G
B
2
A
(
线
3
)
\color{blue}G_{B2A}(线3)
GB2A(线3)和生成器
G
A
2
B
(
线
4
)
\color{blue}G_{A2B}(线4)
GA2B(线4):
L I d e n t i t y ( G A 2 B , G B 2 A ) = E b ∼ p d a t a ( B ) [ ∣ ∣ G A 2 B ( b ) − b ∣ ∣ 1 ] + E a ∼ p d a t a ( A ) [ ∣ ∣ G B 2 A ( a ) − a ∣ ∣ 1 ] (2.4.4) \color{red}\mathcal{L}_{Identity}\left(G_{A2B}, G_{B2A}\right) = \mathbb{E}_{b \sim p_{data}\left(B\right)}\left[||G_{A2B}\left(b\right) - b||_{1}\right] + \mathbb{E}_{a \sim p_{data}\left(A\right)}\left[||G_{B2A}\left(a\right) - a||_{1}\right]\tag{2.4.4} LIdentity(GA2B,GB2A)=Eb∼pdata(B)[∣∣GA2B(b)−b∣∣1]+Ea∼pdata(A)[∣∣GB2A(a)−a∣∣1](2.4.4)- 生成网络loss(
G
B
2
A
G_{B2A}
GB2A Ident_loss):
- i m g A → G B 2 A → i m g A i d img_A \rightarrow G_{B2A} \rightarrow {img_A}_{id} imgA→GB2A→imgAid
- 输入 i m g A img_A imgA,输出 i m g A i d {img_A}_{id} imgAid,目标 i m g s A imgs_A imgsA;
- 生成网络loss(
G
A
2
B
G_{A2B}
GA2B Ident_loss):
- i m g B → G A 2 B → i m g B i d img_B \rightarrow G_{A2B} \rightarrow {img_B}_{id} imgB→GA2B→imgBid
- 输入 i m g B img_B imgB,输出 i m g B i d {img_B}_{id} imgBid,目标 i m g B img_B imgB 。
- 生成网络loss(
G
B
2
A
G_{B2A}
GB2A Ident_loss):
2.4.4 整体损失
- 整体损失可以写成:
L G A N ( G A 2 B , G B 2 A , D A , D B ) = L G A N ( G A 2 B , D B , A , B ) + L G A N ( G B 2 A , D A , B , A ) + λ c y c L c y c ( G A 2 B , G B 2 A ) + λ i d L I d e n t i t y ( G A 2 B , G B 2 A ) (2.4.5) \color{red}\begin{array}{ll} \mathcal{L}_{GAN}\left(G_{A2B}, G_{B2A}, D_{A}, D_{B}\right) &=\mathcal{L}_{GAN}\left(G_{A2B}, D_{B}, A, B\right) + \mathcal{L}_{GAN}\left(G_{B2A}, D_{A}, B, A\right) \\& +\lambda_{cyc}\mathcal{L}_{cyc}\left(G_{A2B}, G_{B2A}\right) + \lambda_{id} \mathcal{L}_{Identity}\left(G_{A2B}, G_{B2A}\right)\end{array}\tag{2.4.5} LGAN(GA2B,GB2A,DA,DB)=LGAN(GA2B,DB,A,B)+LGAN(GB2A,DA,B,A)+λcycLcyc(GA2B,GB2A)+λidLIdentity(GA2B,GB2A)(2.4.5) - 我们需要求解:
G A 2 B ∗ , G B 2 A ∗ = arg min G A 2 B , G B 2 A min D A , D B L G A N ( G A 2 B , G B 2 A , D A , D B ) (2.4.6) \color{red}G_{A2B}^{*}, G_{B2A}^{*} = \arg \min_{G_{A2B}, G_{B2A}} \min_{D_{A}, D_{B}} \mathcal{L}_{GAN}\left(G_{A2B}, G_{B2A}, D_{A}, D_{B}\right)\tag{2.4.6} GA2B∗,GB2A∗=argGA2B,GB2AminDA,DBminLGAN(GA2B,GB2A,DA,DB)(2.4.6) - 对于原始架构,作者使用:
- 对于生成网络:两个 stride-2 卷积、几个残差块和两个带 stride 的分数步长卷积
- 对于生成网络:instance normalization
- 对于判别器:用PatchGAN
- GAN 目标的最小二乘损失。
3. 讨论
3.1 去掉重构误差?模型是否还有效?
模型仍然有效,只是收敛比较慢,毕竟缺少了重构误差这样的强引导信息。以及,虽然实现了风格迁移,但是人物的一些属性改变了,比如可能出现『变性』、『变脸』,而姿态在转换的时候一般不出现错误。这表明: 对 偶 重 构 误 差 能 够 引 导 模 型 在 迁 移 的 时 候 保 留 图 像 固 有 的 属 性 \color{red}对偶重构误差能够引导模型在迁移的时候保留图像固有的属性 对偶重构误差能够引导模型在迁移的时候保留图像固有的属性; 而 对 抗 l o s s 则 负 责 确 定 模 型 该 学 什 么 , 该 怎 么 迁 移 \color{red}而对抗loss则负责确定模型该学什么,该怎么迁移 而对抗loss则负责确定模型该学什么,该怎么迁移。
3.2 GAN(generative adversarial network)的生成模型为什么不直接用VAE(variational autoencoder)或者AE(autoencoder)?
- AE很难生成样本的原因是它对隐变量空间没有限制,很可能编码空间仍是一个非线性带边界的空间,随机取样时很有可能并不在编码空间内。VAE则限制了隐变量和先验的关系,取样时更有可能在解码器的“定义域”内。
- 参考:1. Adversarial Autoencoders
3.3 为什么Cycle一致性损失和 Identity映射损失要用l1范式?
因为:
l
1
正
则
是
稀
疏
作
用
,
先
验
分
布
是
L
a
p
l
a
c
e
分
布
;
l
2
正
则
是
绝
对
值
最
小
,
先
验
分
布
是
G
a
u
s
s
i
a
n
分
布
\color{red}l_1正则是稀疏作用,先验分布是Laplace分布;l_2正则是绝对值最小,先验分布是Gaussian分布
l1正则是稀疏作用,先验分布是Laplace分布;l2正则是绝对值最小,先验分布是Gaussian分布。