论文题目:
A Unified Feature Disentangler for Multi-Domain Image Translation and Manipulation
论文:
A Unified Feature Disentangler for Multi-Domain Image Translation and Manipulationarxiv.org代码:
Alexander-H-Liu/UFDNgithub.com该篇文章发表于2018NIPS
论文的不同之处:可以在多个域上进行变换,不同的域对应的图像在潜在空间上有同一个表示,即文中所说的域不变表示。当在这个 潜在空间上获得域不变的表示时,再结合域的编码信息就能生成自己想要那个图像域的图像吧。
一个基本模型: Self-supervised feature disentanglement
简单的说,学习多个域的解开表示,常用的方法是使用VAE(由一个编码器E和一个生成器G组成)。但该方法存在的不足,就是通过潜在编码z可以生成原先的图像,但该潜在编码难以保证不包含域的信息,即将不同域的信息从z中分离开来,z只包含域不变的信息。
整个模型图如下:
先简单做一下说明:z是域不变编码,Dv是域判别器,原本z不带有域的信息,通过Dv实现这个目的。但这样生成的图像容易模糊,G和Dx进一步挺高生成图像质量和保证z不带有域的信息。
作者提出的解决方法:使用对抗域分类器(adversarial domain classification)
这种思想在Diverse Image-to-Image Translation via Disentangled Representations这篇文章也有类似的想法,但具体上不同。在这篇文章中,作者采用的是训练一个 域判别器Dv,Dv的目标是准确分出每一个潜在编码是属于哪一个域,而域的编码可以使用一个独热编码或是多个独热编码拼接的方式。关于独热编码的部分,作者提供的代码是这样写的:
domain_code = np.concatenate([np.repeat(np.array([[*([1]*int(code_dim/3)),
*([0]*int(code_dim/3)),
*([0]*int(code_dim/3))]]),batch_size,axis=0),
np.repeat(np.array([[*([0]*int(code_dim/3)),
*([1]*int(code_dim/3)),
*([0]*int(code_dim/3))]]),batch_size,axis=0),
np.repeat(np.array([[*([0]*int(code_dim/3)),
*([0]*int(code_dim/3)),
*([1]*int(code_dim/3))]]),batch_size,axis=0)],
axis=0)
另外,关于域判别器的,输入的是潜在表示z,生成的是对每个域的概率分布lv
而编码器E是迷惑判别器Dv,目标是:
这部分对应的代码如下:
# Train Feature Discriminator ##跟新潜在空间的判别器
opt_df.zero_grad() ###
enc_x = vae(input_img,return_enc=True).detach() ##这里只是返回潜在别的编码,对应文章中的z
code_pred = d_feat(enc_x) ##这个是全连接层,好几个全连接层,在潜在编码上对潜在编码进行分类,对应文章中的vc
df_loss = clf_loss(code_pred,code) ##这个是域判别器,判断样本来自哪一个域,希望可以正确判断出来,和后面的VAE形成对抗学习
df_loss.backward()
opt_df.step() ##域判别器损失
其中:
invert_code = 1-code ##相反码
### Feature space adversarial Phase
enc_x = vae(input_img,return_enc=True)
domain_pred = d_feat(enc_x)
adv_code_loss = clf_loss(domain_pred,invert_code) ##这一步是想VAE生成的潜在编码与相反的域编码相关性也很大,但在前面时还要与对应的域编码相关性很大,构成一种对抗学习的关系
feature_loss = loss_lambda['feat_domain']['cur']*adv_code_loss
feature_loss.backward()
通过对抗学习后使的潜在编码不包含域的信息,但在解决上面的问题后,还存在的一个问题是:通过VAE方法生成的图像比较模糊,难以保证生成的图像是高质量的,为解决这个问题,作者还提出了一个图像判别器Dx,Dx的作用有两个:(1)提高生成的图像质量和进一步解开潜在编码的鱼信息。
首先第一个是对抗学习,使生成的图像真实,第二个是分类的损失,单独的z难以分辨出是属于哪一个域的,结合上域编码后(insert_attrs=trans_code)就可以分类出属于哪一个域,所示进一步加强了潜在空间不包含域的信息。
这部分的代码:
# Train Pixel Discriminator
opt_dp.zero_grad()
pix_real_pred,pix_real_code_pred = d_pix(input_img) ##是不是真实,来自哪一个域
fake_img = vae(input_img,insert_attrs=trans_code)[0].detach() ##重新生成的假样本
pix_fake_pred, _ = d_pix(fake_img) ##pix space上的对抗学习
pix_real_pred = pix_real_pred.mean()
pix_fake_pred = pix_fake_pred.mean()
gp = loss_lambda['gp']['cur']*calc_gradient_penalty(d_pix,input_img.data,fake_img.data) ###WGAN中的梯度损失
pix_code_loss = clf_loss(pix_real_code_pred,code)
d_pix_loss = pix_code_loss + pix_fake_pred - pix_real_pred + gp ##d_pix判别器希望可以正确将样本所属的域判别出来,同时判别出真假样本(将假样本判别为0,真样本判别为1)
d_pix_loss.backward()
opt_dp.step()
### Pixel space adversarial Phase
enc_x = vae(input_img,return_enc=True).detach() ##返回潜在编码
fake_img = vae.decode(enc_x,trans_code) ##返回重构图像
recon_enc_x = vae(fake_img,return_enc=True) ##返回重构图像再编码
adv_pix_loss, pix_code_pred = d_pix(fake_img) ##重构图像的判断真假和判断来自哪一个域
adv_pix_loss = adv_pix_loss.mean() ##判断真假的均值
pix_clf_loss = clf_loss(pix_code_pred,trans_code) ##判断域的损失
pixel_loss = - loss_lambda['pix_adv']['cur']*adv_pix_loss + loss_lambda['pix_clf']['cur']*pix_clf_loss ##vae希望成的图像能分辨出域但希望可以骗过域判别器,这一步其实就是生成假样本和真样本对抗训练pix_clf分辨真假样本的能力,同时无论真假样本都要分辨出域信息,这样可以保证潜在表示不包含有任何的域相关的信息
pixel_loss.backward()
opt_vae.step()
另外重构部分的代码如下:
### Reconstruction Phase
recon_batch, mu, logvar = vae(input_img,insert_attrs = code) ##这里的code就是域码,就是所属于哪一个图像域
mse,kl = vae_loss(recon_batch, input_img, mu, logvar, reconstruct_loss) #.view(batch_size,-1)
recon_loss = (loss_lambda['pix_recon']['cur']*mse+loss_lambda['kl']['cur']*kl)
recon_loss.backward() ##减少重构误差和KL散度
自己的总结:
(1)使用了adversarial domain classification,在潜在空间上得到域不变编码;
(2)图像判别器Dx是进一步对潜在编码进行解开,同时提高生成图像的质量。