数据集难找?GAN生成你想要的数据!!!

GAN生成对抗网络学习笔记

1.GAN诞生背后的故事:

GAN创始人 Ian Goodfellow 在酒吧微醉后与同事讨论学术问题,当时灵光乍现提出了GAN初步的想法,不过当时并没有得到同事的认可,在从酒吧回去后发现女朋友已经睡了,于是自己熬夜写了代码,发现还真有效果,于是经过一番研究后,GAN就诞生了,一篇开山之作。论文《Generative Adversarial Nets》首次提出GAN。

论文链接:https://arxiv.org/abs/1406.2661

2.GAN的原理:

GAN的主要灵感来源于博弈论中零和博弈的思想,应用到深度学习神经网络上来说,就是通过生成网络G(Generator)和判别网络D(Discriminator)不断博弈,进而使G学习到数据的分布,如果用到图片生成上,则训练完成后,G可以从一段随机数中生成逼真的图像。G, D的主要功能是:

  •   G是一个生成式的网络,它接收一个随机的噪声z(随机数),通过这个噪声生成图像 

  • D是一个判别网络,判别一张图片是不是“真实的”。它的输入参数是x,x代表一张图片,输出D(x)代表x为真实图片的概率,如果为1,就代表100%是真实的图片,而输出为0,就代表不可能是真实的图片

训练过程中,生成网络G的目标就是尽量生成真实的图片去欺骗判别网络D。而D的目标就是尽量辨别出G生成的假图像和真实的图像。这样,G和D构成了一个动态的“博弈过程”,最终的平衡点即纳什均衡点.


通俗意思就是在犯罪分子造假币和警察识别假币的过程中            

     [1]生成模型G相当于制造假币的一方,其目的是根据看到的钱币情况和警察的识别技术,去尽量生成更加真实的、警察识别不出的假币。           

     [2]判别模型D相当于识别假币的一方,其目的是尽可能的识别出犯罪分子制造的假币。这样通过造假者和识假者双方的较量和朝目的的改进,使得最后能达到生成模型能尽可能真的钱币、识假者判断不出真假的纳什均衡效果(真假币概率都为0.5)。


如图所示:


3.GAN的原理图:


4.GAN的特点:

  1.  相比较传统的模型,他存在两个不同的网络,而不是单一的网络,并且训练方式采用的是对抗训练方式

  2. GAN中G的梯度更新信息来自判别器D,而不是来自数据样本


5.GAN 的优点:

  1. GAN是一种生成式模型,相比较其他生成模型(玻尔兹曼机和GSNs)只用到了反向传播,而不需要复杂的马尔科夫链

  2. 相比其他所有模型, GAN可以产生更加清晰,真实的样本

  3. GAN采用的是一种无监督的学习方式训练,可以被广泛用在无监督学习和半监督学习领域

  4. 相比于变分自编码器, GANs没有引入任何决定性偏置( deterministic bias),变分方法引入决定性偏置,因为他们优化对数似然的下界,而不是似然度本身,这看起来导致了VAEs生成的实例比GANs更模糊

  5. 相比VAE, GANs没有变分下界,如果鉴别器训练良好,那么生成器可以完美的学习到训练样本的分布.换句话说,GANs是渐进一致的,但是VAE是有偏差的

  6. GAN应用到一些场景上,比如图片风格迁移,超分辨率,图像补全,去噪,避免了损失函数设计的困难,不管三七二十一,只要有一个的基准,直接上判别器,剩下的就交给对抗训练了。


6.GAN的缺点:

  1. 训练GAN需要达到纳什均衡,有时候可以用梯度下降法做到,有时候做不到.我们还没有找到很好的达到纳什均衡的方法,所以训练GAN相比VAE或者PixelRNN是不稳定的,但我认为在实践中它还是比训练玻尔兹曼机稳定的多

  2. GAN不适合处理离散形式的数据,比如文本

  3. GAN存在训练不稳定、梯度消失、模式崩溃的问题(目前已解决)


7.训练GAN的一些技巧:

  1.  输入规范化到(-1,1)之间,最后一层的激活函数使用tanh(BEGAN除外)

  2.  使用wassertein GAN的损失函数,

  3. 如果有标签数据的话,尽量使用标签,也有人提出使用反转标签效果很好,另外使用标签平滑,单边标签平滑或者双边标签平滑

  4.  使用mini-batch norm, 如果不用batch norm 可以使用instance norm 或者weight norm

  5. 避免使用RELU和pooling层,减少稀疏梯度的可能性,可以使用leakrelu激活函数

  6. 优化器尽量选择ADAM,学习率不要设置太大,初始1e-4可以参考,另外可以随着训练进行不断缩小学习率,

  7. 给D的网络层增加高斯噪声,相当于是一种正则。


8.GAN的延伸有哪些:

DCGAN
CGAN
ACGAN
infoGAN
WGAN
SSGAN
Pix2Pix GAN
Cycle  GAN

9.GAN可以做什么:答案是生成数据

生成音频
生成图片(动物:猫,狗等;人脸图片,人脸图转动漫图等)
.......

先来个美食图缓一缓(学累就先吃一点东西,哈哈哈)

继续!!!!!

10.GAN的经典案例:生成手写数字图片

  • 源码和数据集获取方式在下方

  • 有py格式和ipynb格式两种(代码是一样的)

代码如下:

# -*- coding: utf-8 -*-
"""
Created on 2020-10-31


@author: 李运辰
"""
#导入数据包
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
#get_ipython().run_line_magic('matplotlib', 'inline')
import numpy as np
import glob
import os


# # 输入
(train_images,train_labels),(_,_)=tf.keras.datasets.mnist.load_data()


train_images  = train_images.astype('float32')


# # 数据预处理
train_images=train_images.reshape(train_images.shape[0],28,28,1).astype('float32')


#归一化 到【-1,1】
train_images = (train_images -127.5)/127.5


BTATH_SIZE=256
BUFFER_SIZE=60000


#输入管道
datasets = tf.data.Dataset.from_tensor_slices(train_images)


#打乱乱序,并取btath_size
datasets  =  datasets.shuffle(BUFFER_SIZE).batch(BTATH_SIZE)


# # 生成器模型
def generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(256,input_shape=(100,),use_bias=False))
    #Dense全连接层,input_shape=(100,)长度100的随机向量,use_bias=False,因为后面有BN层
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())#激活
    
    #第二层
    model.add(layers.Dense(512,use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())#激活
    
    #输出层
    model.add(layers.Dense(28*28*1,use_bias=False,activation='tanh'))
    model.add(layers.BatchNormalization())
    
    model.add(layers.Reshape((28,28,1)))#变成图片 要以元组形式传入
    
    return model
    
# # 辨别器模型
def discriminator_model():
    model = keras.Sequential()
    model.add(layers.Flatten())
    
    model.add(layers.Dense(512,use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())#激活
    
    model.add(layers.Dense(256,use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())#激活
    
    model.add(layers.Dense(1))#输出数字,>0.5真实图片
    
    return model
    
# # loss函数
cross_entropy=tf.keras.losses.BinaryCrossentropy(from_logits=True)#from_logits=True因为最后的输出没有激活


# # 生成器损失函数
def generator_loss(fake_out):#希望fakeimage的判别输出fake_out判别为真
    return cross_entropy(tf.ones_like(fake_out),fake_out)




# # 判别器损失函数
def discriminator_loss(real_out,fake_out):#辨别器的输出 真实图片判1,假的图片判0
    real_loss=cross_entropy(tf.ones_like(real_out),real_out)
    fake_loss=cross_entropy(tf.zeros_like(fake_out),fake_out)
    return real_loss+fake_loss 




# # 优化器


generator_opt=tf.keras.optimizers.Adam(1e-4)#学习速率
discriminator_opt=tf.keras.optimizers.Adam(1e-4)


EPOCHS=500
noise_dim=100 #长度为100的随机向量生成手写数据集
num_exp_to_generate=16 #每步生成16个样本
seed=tf.random.normal([num_exp_to_generate,noise_dim]) #生成随机向量观察变化情况


# # 训练
generator=generator_model()
discriminator=discriminator_model()




# # 定义批次训练函数
def train_step(images):
    noise = tf.random.normal([num_exp_to_generate,noise_dim])
    
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        #判别真图片
        real_out = discriminator(images,training=True)
        #生成图片
        gen_image = generator(noise,training=True)
        #判别生成图片
        fake_out = discriminator(gen_image,training=True)
        
        
        #损失函数判别
        gen_loss = generator_loss(fake_out)
        disc_loss = discriminator_loss(real_out,fake_out)
    
    #训练过程
    #生成器与生成器可训练参数的梯度
    gradient_gen = gen_tape.gradient(gen_loss,generator.trainable_variables) 
    gradient_disc = disc_tape.gradient(disc_loss,discriminator.trainable_variables)
    
    #优化器优化梯度
    generator_opt.apply_gradients(zip(gradient_gen,generator.trainable_variables))
    discriminator_opt.apply_gradients(zip(gradient_disc,discriminator.trainable_variables))
        
# # 可视化
def generator_plot_image(gen_model,test_noise):
    pre_images = gen_model(test_noise,training=False)
    #绘图16张图片在一张4x4
    fig = plt.figure(figsize=(4,4))
    for i in range(pre_images.shape[0]):
        plt.subplot(4,4,i+1) #从1开始排
        plt.imshow((pre_images[i,:,:,0]+1)/2,cmap='gray') #归一化,灰色度
        plt.axis('off') #不显示坐标轴
    plt.show()


def train(dataset,epochs):
     for epoch in range(epochs):
        for image_batch in dataset:
            train_step(image_batch)
        #print('第'+str(epoch+1)+'次训练结果')
        if epoch%10==0:
            print('第'+str(epoch+1)+'次训练结果')
            generator_plot_image(generator,seed)


train(datasets,EPOCHS)




训练结果:

  • 第1次训练结果

  • 第100次训练结果

结论:

在100次训练后,可以明显看到数字的内容,到训练了300次之后就可以很清楚看到生成的数字效果,但300次之后,400,500次效果逐渐下降。图片内容变模糊。

正文结束!!!!

源码和数据集获取方法

公众号回复【GAN】免费获取

欢迎关注公众号:Python爬虫数据分析挖掘

记录学习python的点点滴滴;

回复【开源源码】免费获取更多开源项目源码;

公众号每日更新python知识和【免费】工具;

本文已同步到【开源中国】和【腾讯云社区】;

  • 4
    点赞
  • 40
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

lyc2016012170

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值