CoGAN

The core idea of the paper is to train a pair of GANs under a weight-sharing constraint on their network layers. The model consists of two generator networks and two discriminator networks.

[Figure: CoGAN architecture, two coupled GANs with shared layers]

The training data are unpaired images from two domains, Domain1 and Domain2. We want the two trained generators g1 and g2, given the same input vector z, to produce images that share the same high-level information while differing in low-level detail. To this end, the first few layers of the two generators, which decode the high-level features, share their weights, and the last few layers of the two discriminators f1 and f2 likewise share their weights, as shown in the figure above.
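The coupling on the generator side can be sketched in a few lines of NumPy (all shapes and weight names here are made up for illustration, not taken from the repo): both generators run the same shared trunk on z and diverge only in a domain-specific final layer.

```python
import numpy as np

rng = np.random.default_rng(0)

# Shared trunk: the first layers of g1 and g2 use the SAME weights,
# so the high-level structure decoded from z is identical for both.
W_shared = rng.normal(size=(16, 32))

# Domain-specific heads: the last layers are NOT shared, so each
# generator renders that shared structure in its own domain's style.
W_head1 = rng.normal(size=(32, 8))
W_head2 = rng.normal(size=(32, 8))

def generate(z, W_head):
    h = np.tanh(z @ W_shared)  # shared early layers
    return h @ W_head          # unshared output layer

z = rng.normal(size=(1, 16))   # one common noise vector fed to both generators
x1 = generate(z, W_head1)
x2 = generate(z, W_head2)
```

Because both calls pass through `W_shared`, the intermediate representation is identical; only the unshared heads make the two outputs differ.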

The code discussed below is from: https://github.com/andrewliao11/CoGAN-tensorflow

The two generators have identical structure, as do the two discriminators; the input argument share_params controls whether the weights are shared.

Generator code:

    def generator(self, z, y=None, share_params=False, reuse=False, name='G'):
        if '1' in name:
            branch = '1'
        elif '2' in name:
            branch = '2'

        # layers that share the variables
        s = self.output_size
        s2, s4 = int(s/2), int(s/4)
        h0 = prelu(self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin', reuse=share_params), reuse=share_params),
                   name='g_h0_prelu', reuse=share_params)

        h1 = prelu(self.g_bn1(linear(z, self.gf_dim*2*s4*s4, 'g_h1_lin', reuse=share_params), reuse=share_params),
                   name='g_h1_prelu', reuse=share_params)
        h1 = tf.reshape(h1, [self.batch_size, s4, s4, self.gf_dim * 2])

        h2 = prelu(self.g_bn2(deconv2d(h1, [self.batch_size, s2, s2, self.gf_dim * 2],
                   name='g_h2', reuse=share_params), reuse=share_params), name='g_h2_prelu', reuse=share_params)

        # layers that don't share the variables
        with tf.variable_scope(name):
            if reuse:
                tf.get_variable_scope().reuse_variables()
            output = tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s, s, self.c_dim], name='g'+branch+'_h3', reuse=False))

        return output

Discriminator code:

    def discriminator(self, image, y=None, share_params=False, reuse=False, name='D'):
        # select the corresponding batchnorm1 (not shared)
        if '1' in name:
            d_bn1 = self.d1_bn1
            branch = '1'
        elif '2' in name:
            d_bn1 = self.d2_bn1
            branch = '2'

        # layers that don't share variables
        with tf.variable_scope(name):
            if reuse:
                tf.get_variable_scope().reuse_variables()

            h0 = prelu(conv2d(image, self.c_dim, name='d'+branch+'_h0_conv', reuse=False),
                       name='d'+branch+'_h0_prelu', reuse=False)

            h1 = prelu(d_bn1(conv2d(h0, self.df_dim, name='d'+branch+'_h1_conv', reuse=False), reuse=reuse),
                       name='d'+branch+'_h1_prelu', reuse=False)
            h1 = tf.reshape(h1, [self.batch_size, -1])

        # layers that share variables
        h2 = prelu(self.d_bn2(linear(h1, self.dfc_dim, 'd_h2_lin', reuse=share_params), reuse=share_params),
                   name='d_h2_prelu', reuse=share_params)

        h3 = linear(h2, 1, 'd_h3_lin', reuse=share_params)

        return tf.nn.sigmoid(h3), h3
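The discriminators mirror the generators: domain-specific first layers feed one shared decision head. A minimal NumPy sketch of that split (shapes and names are illustrative, not from the repo):

```python
import numpy as np

rng = np.random.default_rng(1)

# Domain-specific feature extractors: the first layers are NOT shared,
# since low-level image statistics differ between the two domains.
W_feat1 = rng.normal(size=(8, 32))
W_feat2 = rng.normal(size=(8, 32))

# Shared head: the last layers ARE shared, so both discriminators
# apply the same high-level real/fake decision rule.
W_head = rng.normal(size=(32, 1))

def discriminate(x, W_feat):
    h = np.tanh(x @ W_feat)   # unshared first layers
    logit = h @ W_head        # shared last layers
    return 1.0 / (1.0 + np.exp(-logit)), logit

x1 = rng.normal(size=(1, 8))            # a (fake) Domain1 image
prob1, logit1 = discriminate(x1, W_feat1)
```

Like the repo's `discriminator`, this returns both the sigmoid probability and the raw logit, since the loss functions below want logits.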

Given the noise vector z and the attribute vector y (encoding features such as hair color and age), the generators produce images:

    # input of the generator is the concat of z, y
    self.G1 = self.generator(self.z, self.y, share_params=False, reuse=False, name='G1')
    self.G2 = self.generator(self.z, self.y, share_params=True, reuse=False, name='G2')

The real images from the two domains and the generated (fake) images are each fed to the discriminators:

    # input the real images (note: the discriminator returns (probability, logits))
    self.D1, self.D1_logits = self.discriminator(self.images1, self.y, share_params=False, reuse=False, name='D1')
    self.D2, self.D2_logits = self.discriminator(self.images2, self.y, share_params=True, reuse=False, name='D2')
    # input the fake images
    self.D1_, self.D1_logits_ = self.discriminator(self.G1, self.y, share_params=True, reuse=True, name='D1')
    self.D2_, self.D2_logits_ = self.discriminator(self.G2, self.y, share_params=True, reuse=True, name='D2')
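The `reuse=True` on the fake-image calls matters: the discriminator applied to real and fake images must be one and the same network, so the second call has to fetch the variables created by the first call instead of creating fresh ones. A toy variable store mimicking how `tf.get_variable` behaves under variable-scope reuse (the store itself is a hypothetical stand-in, not TensorFlow code):

```python
class VariableStore:
    """Toy stand-in for a tf.variable_scope: reuse=False creates a
    variable, reuse=True fetches the one created earlier."""
    def __init__(self):
        self._vars = {}

    def get_variable(self, name, init, reuse):
        if reuse:
            # Must already exist; this is why the fake pass sets reuse=True.
            return self._vars[name]
        if name in self._vars:
            raise ValueError(f"variable {name} already exists; set reuse=True")
        self._vars[name] = init
        return self._vars[name]

scope = VariableStore()
# first (real-image) pass creates the weight...
w_real = scope.get_variable('d1_h0_conv/w', init=[0.5], reuse=False)
# ...second (fake-image) pass reuses the very same object.
w_fake = scope.get_variable('d1_h0_conv/w', init=[9.9], reuse=True)
```

The same mechanism, driven by `share_params`, is what ties the shared layers of the two GANs together.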

Soft target labels are used in the loss functions: real images are matched against 0.9 and fake images against 0.1 instead of hard 1/0 labels (one-sided label smoothing).

Loss functions for GAN1:

    self.d1_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D1_logits, tf.ones_like(self.D1)*0.9))
    self.d1_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D1_logits_, tf.ones_like(self.D1_)*0.1))
    self.g1_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D1_logits_, tf.ones_like(self.D1_)*0.9))

    self.d1_loss = self.d1_loss_real + self.d1_loss_fake

Loss functions for GAN2:


    self.d2_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D2_logits, tf.ones_like(self.D2)*0.9))
    self.d2_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D2_logits_, tf.ones_like(self.D2_)*0.1))
    self.g2_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D2_logits_, tf.ones_like(self.D2_)*0.9))
    self.d2_loss = self.d2_loss_real + self.d2_loss_fake
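The `*0.9` / `*0.1` targets can be checked by hand. `tf.nn.sigmoid_cross_entropy_with_logits` computes the numerically stable form `max(x, 0) - x*t + log(1 + exp(-|x|))` for logit x and target t; a NumPy sketch with made-up logit values (the numbers are arbitrary examples, not outputs of the model):

```python
import numpy as np

def sigmoid_xent(logits, targets):
    # Stable form of -t*log(sigmoid(x)) - (1-t)*log(1-sigmoid(x)),
    # as computed by tf.nn.sigmoid_cross_entropy_with_logits.
    x, t = np.asarray(logits, float), np.asarray(targets, float)
    return np.maximum(x, 0) - x * t + np.log1p(np.exp(-np.abs(x)))

logits_real = np.array([2.0, -1.0, 3.0])   # D's logits on real images (made up)
logits_fake = np.array([-2.0, 0.5, -1.5])  # D's logits on generated images (made up)

d_loss_real = sigmoid_xent(logits_real, 0.9 * np.ones(3)).mean()  # real -> soft label 0.9
d_loss_fake = sigmoid_xent(logits_fake, 0.1 * np.ones(3)).mean()  # fake -> soft label 0.1
g_loss      = sigmoid_xent(logits_fake, 0.9 * np.ones(3)).mean()  # G wants fakes judged real
d_loss = d_loss_real + d_loss_fake
```

The generator loss reuses the fake logits but with the "real" target 0.9, which is the usual non-saturating GAN objective.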

Experimental results

Handwritten digits: the top and bottom rows show samples from generators G1 and G2, respectively.

[Figure: generated digit samples, G1 output (top row) and G2 output (bottom row)]
