GAN生成对抗网络合集(三):InfoGAN和ACGAN-指定类别生成模拟样本的GAN(附代码)

本文介绍了InfoGAN和AC-GAN两种增强GAN能力的方法。InfoGAN通过学习样本的关键维度信息,实现降维和解耦。AC-GAN在判别器中加入分类信息,使生成的模拟数据与其所属类别对应。文中提供了详细的代码实现,包括模型结构、损失函数和优化器,以及可视化结果展示。
摘要由CSDN通过智能技术生成

1 InfoGAN-带有隐含信息的GAN

       InfoGAN是一种把信息论与GAN相融合的神经网络,能够使网络具有信息解读功能。
       GAN的生成器在构建样本时使用了任意的噪声向量x’,并从低维的噪声数据x’中还原出来高维的样本数据。这说明数据x’中含有具有与样本相同的特征
       由于随意使用的噪声都能还原出高维样本数据,表明噪声中的特征数据部分是与无用的数据部分高度地纠缠在一起的,即我们能够知道噪声中含有有用特征,但无法知道哪些是有用特征
       InfoGAN是GAN模型的一种改进,是一种能够学习样本中的关键维度信息的GAN,即对生成样本的噪音进行了细化。先来看它的结构,相比对抗自编码,InfoGAN的思路正好相反,InfoGAN是先固定标准高斯分布作为网络输入,再慢慢调整网络输出去匹配复杂样本分布

在这里插入图片描述
                                                                                                         图3.1 InfoGAN模型

       如图3.1所示,InfoGAN生成器是从标准高斯分布中随机采样来作为输入,生成模拟样本,解码器是将生成器输出的模拟样本还原回生成器输入的随机数中的一部分,判别器是将样本作为输入来区分真假样本。
       InfoGAN的理论思想是将输入的随机标准高斯分布当成噪音数据,并将噪音分为两类,第一类是不可压缩的噪音Z,第二类是可解释性的信息C。假设在一个样本中,决定其本身的只有少量重要的维度,那么大多数的维度是可以忽略的。而这里的解码器可以更形象地叫成重构器,即通过重构一部分输入的特征来确定与样本互信息的那些维度。最终被找到的维度可以代替原始样本的特征(类似PCA算法中的主成份),实现降维、解耦的效果。

2 AC-GAN-带有辅助分类信息的GAN

       AC-GAN(Auxiliary Classifier GAN),即在判别器discriminator中再输出相应的分类概率,然后增加输出的分类与真实分类的损失计算,使生成的模拟数据与其所属的class一一对应。一般来讲,AC-GAN可以属于InfoGAN的一部分,class信息可以作为InfoGAN中的潜在信息,只不过这部分信息可以使用半监督方式来学习。

在这里插入图片描述

在这里插入图片描述

在这里插入图片描述

3 代码

       首先明确,GAN的代码没有目标检测的复杂,以一个目标检测程序demo的篇幅就涵盖了GAN的数据输入、训练、定义网络结构和参数、loss函数和优化器以及可视化部分。
       还可以学习到的是,GAN基本除开两个大的网络框架G和D以外,就是加各种约束(分类信息、隐含信息等)用以生成想要的数据
       下面是代码实现学习MINST数据特征,生成以假乱真的MNIST模拟样本,并发现内部潜在的特征信息。

在这里插入图片描述
代码总纲

  1. 加载数据集;
  2. 定义G和D;
  3. 定义网络模型的参数、输入输出、中间过程(经过G/D)的输入输出;
  4. 定义loss函数和优化器;
  5. 训练和测试(套循环);
  6. 可视化

3.1 加载数据集、引入头文件

       MNIST数据集下载到相应的地址,其加载方式是固定的。

# -*- coding: utf-8 -*-
##################################################################
#  1.引入头文件并加载mnist数据
##################################################################
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from scipy.stats import norm
import tensorflow.contrib.slim as slim

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/media/S318080208/py_pictures/minist/")  # ,one_hot=True)


tf.reset_default_graph()  # 用于清除默认图形堆栈并重置全局默认图形

3.2 定义G和D

  • 生成器G
    通过“两个全连接+两个反卷积(转置卷积slim.conv2d_transpose)”模拟样本的生成,每一层都有BN(批量归一化)处理。
  • 判别器D
    判别器中有使用leaky_relu函数,其余的在slim库里有,不用重新定义;
    判别器也是由“两次卷积+两次全连接”组成。生成的数据可以分别连接不同的输出层产生不同的结果,其中1维的输出层产生判别结果1或0,10维的输出层产生分类结果,2维输出层产生隐含维度信息。
##################################################################
#  2.定义生成器与判别器
##################################################################
def generator(x):  # 生成器函数 : 两个全连接+两个反卷积模拟样本的生成,每一层都有BN(批量归一化)处理
    reuse = len([t for t in tf.global_variables() if t.name.startswith('generator')]) > 0   # 确认该变量作用域没有变量
    # print (x.get_shape())
    with tf.variable_scope('generator', reuse=reuse):
        x = slim.fully_connected(x, 1024)
        # print(x)
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        x = slim.fully_connected(x, 7*7*128)
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        x = tf.reshape(x, [-1, 7, 7, 128])
        # print ('22', tf.tensor.get_shape())
        x = slim.conv2d_transpose(x, 64, kernel_size=[4, 4], stride=2, activation_fn = None)
        # print ('gen',x.get_shape())
        x = slim.batch_norm(x, activation_fn=tf.nn.relu)
        z = slim.conv2d_transpose(x, 1, kernel_size=[4, 4], stride=2, activation_fn=tf.nn.sigmoid)
        # print ('genz',z.get_shape())
    return z


def leaky_relu
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值