文章的思想是在对抗网络的基础上,引入年龄,性别等先验信息.
网络结构图如下:
对于输入图像,即input face,将其输入4个卷积层,stride=2,加一个全连接层,FC_1,得到输出特征向量z,z的维度为50.将特征向量输入generator G网络,得到输出生成图像.
网络中包含两个判别网络,discriminator z,用于判别输入维度50的向量z_prior,以及输入图像的encoder特征向量z.
另一个判别网络为discriminator image,用于判别生成图像和输入图像.
生成网络(generator G)
生成网络输入包括,输入图像特征向量z,输入图像的年龄向量,例如年龄有10个类别,输入图像为25岁,在第4个类别,则年龄向量为age=ones(1,10)*(-1),age(4)=1.另外还包括输入图像的性别向量gender.
# generator: z + label --> generated image
self.G = self.generator(
z=self.z,
y=self.age,
gender=self.gender,
enable_tile_label=self.enable_tile_label,
tile_ratio=self.tile_ratio
)
首先将特征向量z先后与年龄,性别向量串联,
def concat_label(x, label, duplicate=1):
x_shape = x.get_shape().as_list()
if duplicate < 1:
return x
# duplicate the label to enhance its effect, does it really affect the result?
label = tf.tile(label, [1, duplicate])
label_shape = label.get_shape().as_list()
if len(x_shape) == 2:
return tf.concat([x, label],1)
elif len(x_shape) == 4:
label = tf.reshape(label, [x_shape[0], 1, 1, label_shape[-1]])
return tf.concat( [x, label*tf.ones([x_shape[0], x_shape[1], x_shape[2], label_shape[-1]])],3)
if enable_tile_label: