文章的思想是,利用网络层的权重共享约束,训练GAN网络.模型包括两个生成网络,两个判别网络,
训练数据为不成对的两个域Domain1,Domain2的图片,我们希望的是训练的两个生成网络g1,g2能够在输入向量z相同的情况下,生成的图片高频信息相同,低频信息不同.因此在决定高频特征的生成网络的前几层,将两个生成网络的权重共享,并且,将两个判别网络f1,f2的最后几层网络权重共享,如上图所示.
github代码为:https://github.com/andrewliao11/CoGAN-tensorflow
两个生成网络,判别网络的结构相同,通过输入参数share_params控制权重是否共享.
生成网络代码,
def generator(self, z, y=None, share_params=False, reuse=False, name='G'):
    """Build one CoGAN generator branch.

    The front (high-frequency) layers g_h0_lin/g_h1_lin/g_h2 and their batch
    norms are shared between the two branches via ``share_params``; only the
    final deconv layer is branch-specific.

    Args:
        z: input noise tensor.
        y: attribute vector; accepted but never used in this snippet.
            NOTE(review): the surrounding text says z and y are concatenated
            at the generator input — confirm against the full repository.
        share_params: when True, reuse the variables of the shared front
            layers so the two generators' early weights are tied.
        reuse: when True, reuse this branch's own (non-shared) variables.
        name: variable-scope name; must contain '1' or '2' to pick a branch.

    Returns:
        Sigmoid-activated image tensor of shape
        [batch_size, output_size, output_size, c_dim].
    """
    # Select the branch suffix from the scope name.
    # NOTE(review): a name containing neither '1' nor '2' leaves `branch`
    # unbound and raises NameError at the final deconv2d below.
    if '1' in name:
        branch = '1'
    elif '2' in name:
        branch = '2'
    # Layers that share the variables between the two generators.
    s = self.output_size
    s2, s4 = int(s/2), int(s/4)
    h0 = prelu(self.g_bn0(linear(z, self.gfc_dim, 'g_h0_lin', reuse=share_params), reuse=share_params),
               name='g_h0_prelu', reuse=share_params)
    # NOTE(review): h0 is never consumed — h1 is computed from z directly.
    # Possibly a transcription artifact; verify against the upstream repo.
    h1 = prelu(self.g_bn1(linear(z, self.gf_dim*2*s4*s4,'g_h1_lin',reuse=share_params),reuse=share_params),
               name='g_h1_prelu', reuse=share_params)
    h1 = tf.reshape(h1, [self.batch_size, s4, s4, self.gf_dim * 2])
    h2 = prelu(self.g_bn2(deconv2d(h1, [self.batch_size,s2,s2,self.gf_dim * 2],
               name='g_h2', reuse=share_params), reuse=share_params), name='g_h2_prelu', reuse=share_params)
    # Layers that don't share the variable: the branch-specific output layer
    # lives under its own scope ('G1' or 'G2').
    with tf.variable_scope(name):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        output = tf.nn.sigmoid(deconv2d(h2, [self.batch_size, s, s, self.c_dim], name='g'+branch+'_h3', reuse=False))
    return output
判别网络代码,
def discriminator(self, image, y=None, share_params=False, reuse=False, name='D'):
    """Build one CoGAN discriminator branch.

    The front conv layers are branch-specific (scoped under 'D1'/'D2');
    the tail layers d_h2_lin/d_h3_lin and d_bn2 are shared between the
    two discriminators via ``share_params``.

    Args:
        image: input image batch (real or generated).
        y: attribute vector; accepted but never used in this snippet —
            NOTE(review): confirm against the full repository.
        share_params: when True, reuse the shared tail-layer variables.
        reuse: when True, reuse this branch's own (non-shared) variables.
        name: variable-scope name; must contain '1' or '2' to pick a branch.

    Returns:
        A pair ``(tf.nn.sigmoid(h3), h3)`` — probability FIRST, raw logits
        SECOND. Call sites must unpack in this order.
    """
    # Select the corresponding batch norm (branch-specific, not shared).
    # NOTE(review): a name containing neither '1' nor '2' leaves d_bn1 and
    # branch unbound and raises NameError below.
    if '1' in name:
        d_bn1 = self.d1_bn1
        branch = '1'
    elif '2' in name:
        d_bn1 = self.d2_bn1
        branch = '2'
    # Layers that don't share variables: branch-specific front conv stack.
    with tf.variable_scope(name):
        if reuse:
            tf.get_variable_scope().reuse_variables()
        h0 = prelu(conv2d(image, self.c_dim, name='d'+branch+'_h0_conv', reuse=False),
                   name='d'+branch+'_h0_prelu', reuse=False)
        h1 = prelu(d_bn1(conv2d(h0, self.df_dim, name='d'+branch+'_h1_conv', reuse=False), reuse=reuse),
                   name='d'+branch+'_h1_prelu', reuse=False)
        h1 = tf.reshape(h1, [self.batch_size, -1])
    # Layers that share variables: the tied tail of both discriminators.
    # NOTE(review): indentation was lost in this paste; these two layers are
    # placed outside the branch scope so their variables can be shared —
    # confirm against the upstream repo.
    h2 = prelu(self.d_bn2(linear(h1, self.dfc_dim, 'd_h2_lin', reuse=share_params),reuse=share_params),
               name='d_h2_prelu', reuse=share_params)
    h3 = linear(h2, 1, 'd_h3_lin', reuse=share_params)
    return tf.nn.sigmoid(h3), h3
输入向量z(噪声向量),以及attribute向量y(头发颜色,年龄等特征向量),生成图片,
# Input of the generator is the concat of z, y.
# NOTE(review): in the snippet shown, generator() never uses y — confirm
# the concat happens in the full repository.
# Both generators consume the SAME noise z. G1 creates the shared front
# layers (share_params=False); G2 then reuses them (share_params=True),
# which is what ties the two generators' early weights together.
self.G1 = self.generator(self.z, self.y, share_params=False, reuse=False, name='G1')
self.G2 = self.generator(self.z, self.y, share_params=True, reuse=False, name='G2')
两个域的输入图像(real),以及生成图像(fake)分别输入判别网络,
# Input the real images of each domain into their own discriminator branch.
# discriminator() returns (sigmoid probability, raw logits) IN THAT ORDER,
# so unpack as (D, D_logits). The original snippet had the pair swapped,
# which would have stored probabilities in self.D*_logits and fed them into
# sigmoid_cross_entropy_with_logits below.
# D1 creates the shared tail layers (share_params=False); every later call
# reuses them (share_params=True).
self.D1, self.D1_logits = self.discriminator(self.images1, self.y, share_params=False, reuse=False, name='D1')
self.D2, self.D2_logits = self.discriminator(self.images2, self.y, share_params=True, reuse=False, name='D2')
# Input the fake (generated) images, reusing each branch's own variables.
self.D1_, self.D1_logits_ = self.discriminator(self.G1, self.y, share_params=True, reuse=True, name='D1')
self.D2_, self.D2_logits_ = self.discriminator(self.G2, self.y, share_params=True, reuse=True, name='D2')
在损失函数中对标签做了平滑处理(label smoothing:真实样本标签取0.9,生成样本标签取0.1),
GAN1损失函数,
# GAN1 losses with label smoothing: real targets are 0.9 (not 1.0) and fake
# targets are 0.1 (not 0.0), which softens the discriminator's targets.
# NOTE(review): this is the pre-1.0 TensorFlow positional signature
# sigmoid_cross_entropy_with_logits(logits, targets); TF >= 1.0 requires
# keyword arguments (labels=..., logits=...).
self.d1_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D1_logits, tf.ones_like(self.D1)*0.9))
self.d1_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D1_logits_,tf.ones_like(self.D1_)*0.1))
# Generator loss: push the discriminator's output on fakes toward "real".
self.g1_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D1_logits_, tf.ones_like(self.D1_)*0.9))
self.d1_loss = self.d1_loss_real + self.d1_loss_fake
GAN2损失函数,
# GAN2 losses — identical structure to GAN1 (smoothed targets 0.9 / 0.1),
# built on the second discriminator branch's logits.
# NOTE(review): pre-1.0 TF positional signature; TF >= 1.0 requires
# keyword arguments (labels=..., logits=...).
self.d2_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D2_logits, tf.ones_like(self.D2)*0.9))
self.d2_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D2_logits_,tf.ones_like(self.D2_)*0.1))
# Generator loss: push the discriminator's output on fakes toward "real".
self.g2_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(self.D2_logits_, tf.ones_like(self.D2_)*0.9))
self.d2_loss = self.d2_loss_real + self.d2_loss_fake
试验效果
手写字体,上下两行分别为生成网络G1,G2的生成效果,