判别性别的python代码_UFDN论文和代码分析

论文题目:

A Unified Feature Disentangler for Multi-Domain Image Translation and Manipulation

论文:

A Unified Feature Disentangler for Multi-Domain Image Translation and Manipulation​arxiv.org

代码:

Alexander-H-Liu/UFDN​github.com

该篇文章发表于2018NIPS

论文的不同之处:可以在多个域上进行变换,不同的域对应的图像在潜在空间上有同一个表示,即文中所说的域不变表示。当在这个 潜在空间上获得域不变的表示时,再结合域的编码信息就能生成自己想要那个图像域的图像吧。

一个基本模型: Self-supervised feature disentanglement

2e2a19c35f1d0a804eb911e4a6878b78.png

7f0aa28a9ca7280c211525ec184d4594.png

简单的说,学习多个域的解开表示,常用的方法是使用VAE(由一个编码器E和一个生成器G组成)。但该方法存在的不足,就是通过潜在编码z可以生成原先的图像,但该潜在编码难以保证不包含域的信息,即将不同域的信息从z中分离开来,z只包含域不变的信息。

整个模型图如下:

b6fc2a741d7dbdc3f9b9d98b86bccfcb.png

先简单做一下说明:z是域不变编码,Dv是域判别器,原本z不带有域的信息,通过Dv实现这个目的。但这样生成的图像容易模糊,G和Dx进一步挺高生成图像质量和保证z不带有域的信息。

作者提出的解决方法:使用对抗域分类器(adversarial domain classification)

这种思想在Diverse Image-to-Image Translation via Disentangled Representations这篇文章也有类似的想法,但具体上不同。在这篇文章中,作者采用的是训练一个 域判别器Dv,Dv的目标是准确分出每一个潜在编码是属于哪一个域,而域的编码可以使用一个独热编码或是多个独热编码拼接的方式。关于独热编码的部分,作者提供的代码是这样写的:

domain_code = np.concatenate([np.repeat(np.array([[*([1]*int(code_dim/3)),
                                                   *([0]*int(code_dim/3)),
                                                   *([0]*int(code_dim/3))]]),batch_size,axis=0),
                              np.repeat(np.array([[*([0]*int(code_dim/3)),
                                                   *([1]*int(code_dim/3)),
                                                   *([0]*int(code_dim/3))]]),batch_size,axis=0),
                              np.repeat(np.array([[*([0]*int(code_dim/3)),
                                                   *([0]*int(code_dim/3)),
                                                   *([1]*int(code_dim/3))]]),batch_size,axis=0)],
                              axis=0)

另外,关于域判别器的,输入的是潜在表示z,生成的是对每个域的概率分布lv

59edcbeab331099f4ff8f426a3038db9.png

而编码器E是迷惑判别器Dv,目标是:

249f3851e2fc054283f4c16aa741e08b.png

这部分对应的代码如下:

# Train Feature Discriminator  ##跟新潜在空间的判别器
        opt_df.zero_grad()   ###
        
        enc_x = vae(input_img,return_enc=True).detach()  ##这里只是返回潜在别的编码,对应文章中的z
        code_pred = d_feat(enc_x)  ##这个是全连接层,好几个全连接层,在潜在编码上对潜在编码进行分类,对应文章中的vc

        df_loss = clf_loss(code_pred,code)  ##这个是域判别器,判断样本来自哪一个域,希望可以正确判断出来,和后面的VAE形成对抗学习
        df_loss.backward()
     
        opt_df.step()  ##域判别器损失  
其中:
invert_code = 1-code  ##相反码

        ### Feature space adversarial Phase       
        enc_x = vae(input_img,return_enc=True)
        domain_pred = d_feat(enc_x)
        adv_code_loss = clf_loss(domain_pred,invert_code)  ##这一步是想VAE生成的潜在编码与相反的域编码相关性也很大,但在前面时还要与对应的域编码相关性很大,构成一种对抗学习的关系
        
        feature_loss = loss_lambda['feat_domain']['cur']*adv_code_loss
        feature_loss.backward()  

通过对抗学习后使的潜在编码不包含域的信息,但在解决上面的问题后,还存在的一个问题是:通过VAE方法生成的图像比较模糊,难以保证生成的图像是高质量的,为解决这个问题,作者还提出了一个图像判别器Dx,Dx的作用有两个:(1)提高生成的图像质量和进一步解开潜在编码的鱼信息。

3808833a8eb4fb688e021f192dc38509.png

首先第一个是对抗学习,使生成的图像真实,第二个是分类的损失,单独的z难以分辨出是属于哪一个域的,结合上域编码后(insert_attrs=trans_code)就可以分类出属于哪一个域,所示进一步加强了潜在空间不包含域的信息。

这部分的代码:

        # Train Pixel Discriminator
        opt_dp.zero_grad()
        
        pix_real_pred,pix_real_code_pred = d_pix(input_img)  ##是不是真实,来自哪一个域
        
        fake_img = vae(input_img,insert_attrs=trans_code)[0].detach()  ##重新生成的假样本
        pix_fake_pred, _  = d_pix(fake_img)  ##pix space上的对抗学习
        
        pix_real_pred = pix_real_pred.mean()
        pix_fake_pred = pix_fake_pred.mean()

        gp = loss_lambda['gp']['cur']*calc_gradient_penalty(d_pix,input_img.data,fake_img.data)   ###WGAN中的梯度损失
        pix_code_loss = clf_loss(pix_real_code_pred,code)
        
        d_pix_loss = pix_code_loss + pix_fake_pred - pix_real_pred + gp  ##d_pix判别器希望可以正确将样本所属的域判别出来,同时判别出真假样本(将假样本判别为0,真样本判别为1)
        d_pix_loss.backward()
        
        opt_dp.step()


        ### Pixel space adversarial Phase
        enc_x = vae(input_img,return_enc=True).detach()  ##返回潜在编码
        
        fake_img = vae.decode(enc_x,trans_code)  ##返回重构图像
        recon_enc_x = vae(fake_img,return_enc=True)  ##返回重构图像再编码
        adv_pix_loss, pix_code_pred = d_pix(fake_img)   ##重构图像的判断真假和判断来自哪一个域
        adv_pix_loss = adv_pix_loss.mean()  ##判断真假的均值 
        pix_clf_loss = clf_loss(pix_code_pred,trans_code)  ##判断域的损失
        
        
        pixel_loss =  - loss_lambda['pix_adv']['cur']*adv_pix_loss + loss_lambda['pix_clf']['cur']*pix_clf_loss ##vae希望成的图像能分辨出域但希望可以骗过域判别器,这一步其实就是生成假样本和真样本对抗训练pix_clf分辨真假样本的能力,同时无论真假样本都要分辨出域信息,这样可以保证潜在表示不包含有任何的域相关的信息
        pixel_loss.backward()
        
        opt_vae.step()


另外重构部分的代码如下:
        ### Reconstruction Phase
        recon_batch, mu, logvar = vae(input_img,insert_attrs = code)   ##这里的code就是域码,就是所属于哪一个图像域
        mse,kl = vae_loss(recon_batch, input_img, mu, logvar, reconstruct_loss)  #.view(batch_size,-1)
        recon_loss = (loss_lambda['pix_recon']['cur']*mse+loss_lambda['kl']['cur']*kl)
        recon_loss.backward()   ##减少重构误差和KL散度

自己的总结:

(1)使用了adversarial domain classification,在潜在空间上得到域不变编码;

(2)图像判别器Dx是进一步对潜在编码进行解开,同时提高生成图像的质量。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值