Regression by Conditional Adversarial Autoencoder

文章的思想是在对抗网络的基础上,引入年龄,性别等先验信息.

网络结构图如下:

这里写图片描述
对于输入图像,即input face,将其输入4个卷积层,stride=2,加一个全连接层,FC_1,得到输出特征向量z,z的维度为50.将特征向量输入generator G网络,得到输出生成图像.

网络中包含两个判别网络,discriminator z,用于判别输入维度50的向量z_prior,以及输入图像的encoder特征向量z.

另一个判别网络为discriminator image,用于判别生成图像和输入图像.

生成网络(generator G)

生成网络输入包括,输入图像特征向量z,输入图像的年龄向量,例如年龄有10个类别,输入图像为25岁,在第4个类别,则年龄向量为age=ones(1,10)*(-1),age(4)=1.另外还包括输入图像的性别向量gender.

# generator: z + label --> generated image
self.G = self.generator(
    z=self.z,
    y=self.age,
    gender=self.gender,
    enable_tile_label=self.enable_tile_label,
    tile_ratio=self.tile_ratio
)

首先将特征向量z先后与年龄,性别向量串联,

def concat_label(x, label, duplicate=1):
    x_shape = x.get_shape().as_list()
    if duplicate < 1:
        return x
    # duplicate the label to enhance its effect, does it really affect the result?
    label = tf.tile(label, [1, duplicate])
    label_shape = label.get_shape().as_list()
    if len(x_shape) == 2:
        return tf.concat([x, label],1)
    elif len(x_shape) == 4:
        label = tf.reshape(label, [x_shape[0], 1, 1, label_shape[-1]])
        return tf.concat( [x, label*tf.ones([x_shape[0], x_shape[1], x_shape[2], label_shape[-1]])],3)
if enable_tile_label:
    
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值