实例99:使用AEGAN对MNIST数据集压缩特征及重建

  本实例在MNIST数据集上使用AEGAN模型进行特征压缩及重建,并且加入标签信息loss实现AC-GAN网络。其中D和G都是通过卷积网络实现。

实例描述

  使用InfoGAN网络,在其基础上添加自编码网络,将InfoGAN的参数固定,训练反向生成器(自编码网络中的编码器),并将生成的模型用于MNIST数据集样本重建,得到相似的样本。

1.添加反向生成器

  添加反向生成器inversegenerator函数。该函数的功能是将图片生成特征吗,其结构与判别器相似,均为生成器的反向操作,即两个卷积层加上两个全连接层。

#反向生成器定义,结构与判别器类似
def inversegenerator(x):
    reuse = len([t for t in tf.global_variables() if t.name.startswith('inversegenerator')]) > 0
    with tf.variable_scope('inversegenerator', reuse=reuse):
        #两个卷积
        x = tf.reshape(x, shape=[-1, 28, 28, 1])
        x = slim.conv2d(x, num_outputs = 64, kernel_size=[4,4], stride=2, activation_fn=leaky_relu)
        x = slim.conv2d(x, num_outputs=128, kernel_size=[4,4], stride=2, activation_fn=leaky_relu)
        #两个全连接
        x = slim.flatten(x)        
        shared_tensor = slim.fully_connected(x, num_outputs=1024, activation_fn = leaky_relu)
        z = slim.fully_connected(shared_tensor, num_outputs=50, activation_fn = leaky_relu)
    return z  

2.添加自编码网络代码

  自编码网络输入不是真实图片,而是生成器生成的图片generator(z),通过inversegenerator来压缩特征,生成与生成器输入噪声一样维度,然后将生成器当做自编码中的解码器重建出原始生成的图片。
  将自编码还原的图片与GAN生成生成的输入图片进行平方差计算,得到自编码的损失值loss_ae。

z_con = tf.random_normal((batch_size, con_dim))      #2列
z_rand = tf.random_normal((batch_size, rand_dim))    #38列
z = tf.concat(axis=1, values=[tf.one_hot(y, depth = classes_dim), z_con, z_rand])#50列
#生成器生成的图片gen模拟数据
gen = generator(z)
genout= tf.squeeze(gen, -1)


#自编码网络
aelearning_rate =0.01
igen = generator(inversegenerator(generator(z)))    #生成器生成的模拟数据-反生成原始数据-再生成图片的数据
loss_ae = tf.reduce_mean(tf.pow(gen - igen, 2))

#输出
igenout = generator(inversegenerator(x))

3.添加自编码网络的训练参数列表,定义优化器

  自编码网络的训练参数与前面的GAN几乎一样,使用MonitoredTrainingSession来管理检查点文件,定义global_step。定义train_ae优化器,并将global_step放入优化器中。

# 获得各个网络中各自的训练参数
t_vars = tf.trainable_variables()
d_vars = [var for var in t_vars if 'discriminator' in var.name]
g_vars = [var for var in t_vars if 'generator' in var.name]
ae_vars =  [var for var in t_vars if 'inversegenerator' in var.name]

gen_global_step = tf.Variable(0, trainable=False)
global_step = tf.train.get_or_create_global_step()#使用MonitoredTrainingSession,必须有

train_disc = tf.train.AdamOptimizer(0.0001).minimize(loss_d + loss_c + loss_con, var_list = d_vars, global_step = global_step)
train_gen = tf.train.AdamOptimizer(0.001).minimize(loss_g + loss_c + loss_con, var_list = g_vars, global_step = gen_global_step)
train_ae = tf.train.AdamOptimizer(aelearning_rate).minimize(loss_ae, var_list = ae_vars, global_step = global_step)


training_GANepochs = 3   #训练GAN迭代3次数据集
training_aeepochs = 6    #训练AE迭代3次数据集(从3开始到6)
display_step = 1

  本例需要训练GAN和AE两个网络,使用MonitoredTrainingSession管理后只能由有一个global_step,于是将global_step分段来管理两个网络的训练。每次迭代训练都会遍历整个数据集,先GAN迭代3次,在让AE迭代3测。

4.起动session依次训练GAN和AE网络

  使用MonitoredTrainingSession创建session。令程序每2分钟保存一次检查点文件。

with tf.train.MonitoredTrainingSession(checkpoint_dir='log/aecheckpoints',save_checkpoint_secs  =120) as sess:
    
    total_batch = int(mnist.train.num_examples/batch_size)
    print("ae_global_step.eval(session=sess)",global_step.eval(session=sess),int(global_step.eval(session=sess)/total_batch))
    
    for epoch in range( int(global_step.eval(session=sess)/total_batch),training_GANepochs):
        avg_cost = 0.

        # 遍历全部数据集
        for i in range(total_batch):

            batch_xs, batch_ys = mnist.train.next_batch(batch_size)#取数据
            feeds = {x: batch_xs, y: batch_ys}

            # Fit training using batch data
            l_disc, _, l_d_step = sess.run([loss_d, train_disc, global_step],feeds)
            l_gen, _, l_g_step = sess.run([loss_g, train_gen, gen_global_step],feeds)

        # 显示训练中的详细信息
        if epoch % display_step == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f} ".format(l_disc),l_gen)

    print("GAN完成!")
    # 测试
    print ("Result:", loss_d.eval({x: mnist.test.images[:batch_size],y:mnist.test.labels[:batch_size]},session = sess)
                        , loss_g.eval({x: mnist.test.images[:batch_size],y:mnist.test.labels[:batch_size]},session = sess))

在这里插入图片描述
在这里插入图片描述
  从图中可以看出,InfoGAN只会生成属于原始数据分布的图片,而AEGAN会生成与原始图片更相近的图片。
  这种网络有压缩特征与重建两部分用途,重建样本常常用于处理图像的恢复与重建,还可以将重建的模拟数据保存起来空充数据集,也可以应用在超分辨率重建部分;

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值