Multimodal Unsupervised Image-to-Image Translation
Abstract
- 给定源域中的图像,目标是学习目标域中相应图像的条件分布,而不会看到相应图像对的任何示例
- 图像可以分解为内容和样式
- 这个模型可以让用户控制样式,用户只需提供一张图即可
I. Introduction
- 这个方法还能用于超分辨率[1],着色[2],修复[3],属性转移[4]和样式转移[5]
- 现有的一些技术,都是确定性的单模态的映射。因此,无法捕捉可能输出的整体分布。(不过,对于重构类的问题,单模态映射其实是ok的。然而,对于encoder-decoder 框架下的高维数据映射通过表示学习映射到低维数据,再decode低维数据到高维,得到的结果不应该是唯一的。所以,重构类问题其实也可以参考这篇文章。)
- 这篇文章做出来以下假设:
- 图像的潜在空间可以分解为内容空间和样式空间
- 不同域中的图像共享一个共同的内容空间,但不共享样式空间
- 文章想做的事保留内容空间,然后样式空间可以保留随机性。
- 实验结果优于state-of-the-art
- 这个模型可以让用户控制样式,用户只需提供一张图即可
II. Related Works
- GANs
- Image2Image Isola[6] - 多域,多模态transfer
- Style transfer
- Learning disentangle representations: InfoGAN and β − V A E \beta-VAE β−VAE
III. Multimodal Unsupervised Image-to-image Translation
假设
使 x 1 ∈ X 1 x_{1} \in \mathcal{X}_{1} x1∈X1 and x 2 ∈ X 2 x_{2} \in \mathcal{X}_{2} x2∈X2来自不同的图片域。 在非监督图到图转换背景下,我们给定从两个边缘分布 p ( x 1 ) p\left(x_{1}\right) p(x1)和 p ( x 2 ) p\left(x_{2}\right) p(x2)中得到的图,而联合分布 p ( x 1 , x 2 ) p\left(x_{1},x_{2}\right) p(x1,x2)是未知的。于是,我们的目标就是估计两个条件分布 p ( x 2 ∣ x 1 ) p\left(x_{2} | x_{1}\right) p(x2∣x1)和 p ( x 1 ∣ x 2 ) p\left(x_{1} | x_{2}\right) p(x1∣x2),通过学习图到图转换的模型。
不同域中的图像共享一个共同的内容空间,但不共享样式空间。
一对相关的图片
(
x
1
,
x
2
)
(x_1, x_2)
(x1,x2)是来自于联合分布,
x
1
=
G
1
∗
(
c
,
s
1
)
x_1=G^{*}_{1}(c,s_1)
x1=G1∗(c,s1) and
x
2
=
G
2
∗
(
c
,
s
2
)
x_2=G^{*}_{2}(c,s_2)
x2=G2∗(c,s2).
模型
- E 是 encoder, G 是 decoder, D 是 discriminator
min E 1 , E 2 , G 1 , G 2 max D 1 , D 2 L ( E 1 , E 2 , G 1 , G 2 , D 1 , D 2 ) = L G A N x 1 + L G A N x 2 + λ x ( L r e c o n x 1 + L r e c o n x 2 ) + λ c ( L r e c o n c 1 + L r e c o n c 2 ) + λ s ( L r e c o n s 2 + L r e c o n s 2 ) \begin{array}{l}{\min _{E_{1}, E_{2}, G_{1}, G_{2}} \max _{D_{1}, D_{2}} \mathcal{L}\left(E_{1}, E_{2}, G_{1}, G_{2}, D_{1}, D_{2}\right)=\mathcal{L}_{\mathrm{GAN}}^{x_{1}}+\mathcal{L}_{\mathrm{GAN}}^{x_{2}}+} \\ {\lambda_{x}\left(\mathcal{L}_{\mathrm{recon}}^{x_{1}}+\mathcal{L}_{\mathrm{recon}}^{x_{2}}\right)+\lambda_{c}\left(\mathcal{L}_{\mathrm{recon}}^{c_{1}}+\mathcal{L}_{\mathrm{recon}}^{c_{2}}\right)+\lambda_{s}\left(\mathcal{L}_{\mathrm{recon}}^{s_{2}}+\mathcal{L}_{\mathrm{recon}}^{s_{2}}\right)}\end{array} minE1,E2,G1,G2maxD1,D2L(E1,E2,G1,G2,D1,D2)=LGANx1+LGANx2+λx(Lreconx1+Lreconx2)+λc(Lreconc1+Lreconc2)+λs(Lrecons2+Lrecons2)
用GAN去学encoder decoder 以及discriminator,目标是学encoder 和 decoder。
Loss 包括:
- 对抗性损失-GAN
- 图像重构误差- x x x recon
- 特征重构误差- c , s c,s c,s recon
IV. Theoretical Analysis
- Encoder 和 Decoder 之间的关系对Loss 的影响
- Proposition 2
上述命题表明,在达到最优时,编码的样式分布与它们的高斯先验相匹配。 此外,编码的内容分布与生成时的分布匹配,这仅是来自其他域的编码分布。 这表明内容空间变为域不变。q是高斯先验分布
- Proposition 3
- Proposition 4
V. Experiments
- MLP
- AdaIN- 分布normalize
-
AdalN
(
x
,
y
)
=
σ
(
y
)
(
x
−
μ
(
x
)
σ
(
x
)
)
+
μ
(
y
)
\operatorname{AdalN}(x, y)=\sigma(y)\left(\frac{x-\mu(x)}{\sigma(x)}\right)+\mu(y)
AdalN(x,y)=σ(y)(σ(x)x−μ(x))+μ(y)
以下是引用了shaoanlu的Keras的代码
def op_adain(inp):
x = inp[0]
mean, var = tf.nn.moments(x, [1,2], keep_dims=True)
adain_bias = inp[1]
adain_bias = K.reshape(adain_bias, (-1, 1, 1, n_dim_adain))
adain_weight = inp[2]
adain_weight = K.reshape(adain_weight, (-1, 1, 1, n_dim_adain))
out = tf.nn.batch_normalization(x, mean, var, adain_bias, adain_weight, variance_epsilon=1e-7)
return out
def AdaptiveInstanceNorm2d(inp, adain_params, idx_adain):
assert inp.shape[-1] == n_dim_adain
x = inp
idx_head = idx_adain*2*n_dim_adain
adain_weight = Lambda(lambda x: x[:, idx_head:idx_head+n_dim_adain])(adain_params)
adain_bias = Lambda(lambda x: x[:, idx_head+n_dim_adain:idx_head+2*n_dim_adain])(adain_params)
out = Lambda(op_adain)([x, adain_bias, adain_weight])
return out
def res_block_adain(inp, f, adain_params, idx_adain):
x = inp
x = ReflectPadding2D(x)
x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init,
kernel_regularizer=regularizers.l2(w_l2), bias_regularizer=regularizers.l2(w_l2),
use_bias=False, padding="valid")(x)
x = Lambda(lambda x: AdaptiveInstanceNorm2d(x[0], x[1], idx_adain))([x, adain_params])
x = Activation('relu')(x)
x = ReflectPadding2D(x)
x = Conv2D(f, kernel_size=3, kernel_initializer=conv_init,
kernel_regularizer=regularizers.l2(w_l2), bias_regularizer=regularizers.l2(w_l2),
use_bias=False, padding="valid")(x)
x = Lambda(lambda x: AdaptiveInstanceNorm2d(x[0], x[1], idx_adain+1))([x, adain_params])
x = add([x, inp])
return x
这个写法确保了,x代表的是图片上的位置信息,相邻元素的一些位置关系是不变的,但是整体分布是投影到了y的分布中。
Training
训练就是用两个内容一样,样式不同的两组数据集来训练。
目标转换函数
def model_paths(netEnc_content, netEnc_style, netDec):
fn_content_code = K.function([netEnc_content.inputs[0]], [netEnc_content.outputs[0]])
fn_style_code = K.function([netEnc_style.inputs[0]], [netEnc_style.outputs[0]])
fn_decoder_rgb = K.function(netDec.inputs, [netDec.outputs[0]])
fake_output = netDec.outputs[0]
fn_decoder_out = K.function(netDec.inputs, [fake_output])
return fn_content_code, fn_style_code, fn_decoder_out
def translation(src_imgs, style_image, fn_content_code_src, fn_style_code_tar, fn_decoder_rgb_tar, rand_style=False):
# Cross domain translation function
# This funciton is for visualization purpose
"""
Inputs:
src_img: source domain images, shape=(input_batch_size, h, w, c)
style_image: target style images, shape=(input_batch_size, h, w, c)
fn_content_code_src: Source domain K.function of content encoder
fn_style_code_tar: Target domain K.function of style encoder
fn_decoder_rgb_tar: Target domain K.function of decoder
rand_style: sample style codes from normal distribution if set True.
Outputs:
fake_rgb: output tensor of decoder having chennels [R, G, B], shape=(input_batch_size, h, w, c)
"""
batch_size = src_imgs.shape[0]
content_code = fn_content_code_src([src_imgs])[0]
if rand_style:
style_code = np.random.normal(size=(batch_size, n_dim_style))
elif style_image is None:
style_code = fn_style_code_tar([src_imgs])[0]
else:
style_code = fn_style_code_tar([style_image])[0]
fake_rgb = fn_decoder_rgb_tar([style_code, content_code])[0]
return fake_rgb
path_content_code_A, path_style_code_A, path_decoder_A = model_paths(encoder_content_A, encoder_style_A, decoder_A)
path_content_code_B, path_style_code_B, path_decoder_B = model_paths(encoder_content_B, encoder_style_B, decoder_B)
使用的时候就是输入两张图,一张提供内容的图,一张提供style的图。也可以给定一张style图用于重新训练模型。