c语言贪吃蛇最简单代码_让我们跑一个最简单的GAN网络吧!(附Jupyter Notebook 代码)...

58cbcc57d296132edf4a4326cdadf7a0.png

前言:最近在学习生成对抗网络(GAN, Generative Adversarial Networks),为了加深自己的理解,并帮助到想入门的同学,我特意写了这篇文章,教大家一步步搭建一个最简单原始的GAN网络 (Vanilla GAN)。代码后面会有详细(通俗易懂)的解释,大神请自动绕路~欢迎小白玩家围观~~ 查看本文jupyter notebook代码请点击这里。

daf74f61c7938c376aaa7bc195b0ae1a.gif
利用GAN生成Mnist手写体图像(第1, 5, 10, 50, 400次迭代结果)

首先,让我们简单回顾一下什么是GAN

8439abe8f932da38a131ceac46a7408e.png
图1. GAN网络结构 (来自灵魂画手:我)

GAN最早由GoodFellow在2014年提出,查看原始论文请点击这里。GAN结构如图1所示,包含了一个生成器(Generator)和一个判别器 (Discriminator)。生成器的目的是生成以假乱真的图片,而判别器的目的是尽可能区分输入图片的真假。

举一个简单的例子,比如说假钞的流通。犯罪分子希望制作出逼真的假钞,可是警察的鉴定技术也在不断改良,双方互相博弈,互相提高,最终达到一种动态的平衡。讲到这里,是不是感觉很简单?

鉴于这是一个超级良心的教程~大家可以先跟着我一起,把代码实现。实现过程中,有不懂的先不要问(对,憋着),等跑完代码之后,看到酷炫的效果后,我再一步步解释为啥这么写。

好了,是时候放上代码了。来,先导入包。

from keras.datasets import mnist
from keras.layers import Dense, Dropout, Input
from keras.models import Model,Sequential
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from google.colab import drive

然后,读取Keras自带的mnist数据集。在这里我们给出一个读取数据的函数load_data()。

# Load the dataset
def load_data():
  (x_train, y_train), (_, _) = mnist.load_data()
  x_train = (x_train.astype(np.float32) - 127.5)/127.5
  
  # Convert shape from (60000, 28, 28) to (60000, 784)
  x_train = x_train.reshape(60000, 784)
  return (x_train, y_train)

X_train, y_train = load_data()
print(X_train.shape, y_train.shape)

8d177e1451705650c6ee8934ff185717.png
输出X_train.shape, y_train.shape

由于本文我们旨在实现最原始的GAN网络,因此用最简单MLP全连接层来构建生成器(用卷积层当然更好,在这里先不考虑)

def build_generator():
    model = Sequential()
    
    model.add(Dense(units=256, input_dim=100))
    model.add(LeakyReLU(alpha=0.2))
    
    model.add(Dense(units=512))
    model.add(LeakyReLU(alpha=0.2))
    
    model.add(Dense(units=1024))
    model.add(LeakyReLU(alpha=0.2))
    
    model.add(Dense(units=784, activation='tanh'))
    
    model.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    return model

generator = build_generator()
generator.summary()

生成器结构如下图所示:

92d9ebe6a93fc9db472cde19042f2a3a.png

然后建一个判别器,也是一个MLP全连接神经网络:

def build_discriminator():
    model = Sequential()
    
    model.add(Dense(units=1024 ,input_dim=784))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))
       
    model.add(Dense(units=512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))
       
    model.add(Dense(units=256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.3))
      
    model.add(Dense(units=1, activation='sigmoid'))
    
    model.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    return model
  
discriminator = build_discriminator()
discriminator.summary()

判别器结构如图所示:

1d1b1ea44bbc5a6c9666267a9d24729f.png

然后,我们建立一个GAN网络,由discriminator和generator组成。

def build_GAN(discriminator, generator):
    discriminator.trainable=False
    GAN_input = Input(shape=(100,))
    x = generator(GAN_input)
    GAN_output= discriminator(x)
    GAN = Model(inputs=GAN_input, outputs=GAN_output)
    GAN.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    return GAN

GAN = build_GAN(discriminator, generator)
GAN.summary()

GAN结构如下图所示

5bb4b80deca19b328c7ed48eea4fefb2.png

然后我们给出绘制图像的函数,用于把generator生成的假图片画出来:

def draw_images(generator, epoch, examples=25, dim=(5,5), figsize=(10,10)):
    noise= np.random.normal(loc=0, scale=1, size=[examples, 100])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(25,28,28)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='Greys')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('Generated_images %d.png' %epoch)

OK, 最后一步,写一个train函数,来训练GAN网络。在这里我们设置最大迭代次数400,每次迭代生成128张假图片:

def train_GAN(epochs=1, batch_size=128):
    
  #Loading the data
  X_train, y_train = load_data()

  # Creating GAN
  generator= build_generator()
  discriminator= build_discriminator()
  GAN = build_GAN(discriminator, generator)

  for i in range(1, epochs+1):
    print("Epoch %d" %i)
    
    for _ in tqdm(range(batch_size)):
      # Generate fake images from random noiset
      noise= np.random.normal(0,1, (batch_size, 100))
      fake_images = generator.predict(noise)

      # Select a random batch of real images from MNIST
      real_images = X_train[np.random.randint(0, X_train.shape[0], batch_size)]

      # Labels for fake and real images           
      label_fake = np.zeros(batch_size)
      label_real = np.ones(batch_size) 

      # Concatenate fake and real images 
      X = np.concatenate([fake_images, real_images])
      y = np.concatenate([label_fake, label_real])

      # Train the discriminator
      discriminator.trainable=True
      discriminator.train_on_batch(X, y)

      # Train the generator/chained GAN model (with frozen weights in discriminator) 
      discriminator.trainable=False
      GAN.train_on_batch(noise, label_real)

    # Draw generated images every 15 epoches     
    if i == 1 or i % 10 == 0:
      draw_images(generator, i)
train_GAN(epochs=400, batch_size=128)

fc4b1fec6cb3ea97b6dbcf0897898624.png

我用了Google colab自带的GPU,训练400代大约用了十多分钟。如果用jupyter notebook在本机跑,会慢一些 (据说2分钟一代?)。

生成的图片如下图所示

968d3f44d8d5c3abf3d8e9860f3d3466.png
第1次迭代

bd407649656270fdf2bd661f8ebbd5e2.png
第10次迭代

644038fc3e83152fe0f3fd1e5149f716.png
第50次迭代

e19a2a8fa7a57fea2b71c99fc2b0f78a.png
第400次迭代

大功告成,接下来我将一步步解释train_GAN()函数是怎么工作的。

首先,导入数据集,这个容易理解。

  #Loading the data
  X_train, y_train = load_data()

接下来,建立一个GAN网络,GAN由两个神经网络(generator, discriminator)连接而成。

  # Creating GAN
  generator= build_generator()
  discriminator= build_discriminator()
  GAN = build_GAN(discriminator, generator)

然后,建立一个循环(400次迭代)。tqdm用来动态显示每次迭代的进度。

 for i in range(1, epochs+1):
    print("Epoch %d" %i)
    for _ in tqdm(range(batch_size)):

接着,我们生成呈高斯分布的噪声,利用generator,来生成batch_size(128张)图片。每张图片的输入就是一个1*100的噪声矩阵。

  # Generate fake images from random noiset
  noise= np.random.normal(0,1, (batch_size, 100))
  fake_images = generator.predict(noise)

同样的,我们从Mnist数据集中随机挑选128张真实图片。我们给真实图片标注1,给假图片标注0,然后将256张真假图片混合在一起。

 # Select a random batch of real images from MNIST
 real_images = X_train[np.random.randint(0, X_train.shape[0], batch_size)]
 
  # Labels for fake and real images           
  label_fake = np.zeros(batch_size)
  label_real = np.ones(batch_size) 

  # Concatenate fake and real images 
  X = np.concatenate([fake_images, real_images])
  y = np.concatenate([label_fake, label_real])

此时,我们利用上文提到的256张带标签的真假图片,训练discriminator。训练完毕后,discriminator的weights得到了更新。(打个比方,警察通过研究市面上流通的假币,在一起开会讨论,努力研发出了新一代鉴定假钞的方法)。

# Train the discriminator
discriminator.trainable=True
discriminator.train_on_batch(X, y)

然后,我们冻结住discriminator的weights,让discriminator不再变化。然后就开始训练generator (chained GAN)。在GAN的训练中,我们输入一堆噪声,期待的输出是将假图片预测为真。在这个过程中,generator继续生成假图片,送到discriminator检验,得到检验结果,如果被鉴定为假,就不断更新自己的权重(假钞贩子不断改良造假技术),直到discriminator将加图片鉴定为真图片(直到当前鉴定假钞的技术无法识别出假钞)。

 # Train the generator/chained GAN model (with frozen weights in discriminator) 
 discriminator.trainable=False
 GAN.train_on_batch(noise, label_real)

OK,此时一次迭代进行完毕。接下来是第2, 3, ...次迭代。

现在,我们总结一下每次迭代发生了什么:

  1. Generator利用自己最新的权重,生成了一堆假图片。
  2. Discrminator根据真假图片的真实label,不断训练更新自己的权重,直到可以顺利鉴别真假图片。
  3. 此时discriminator权重被固定,不再发生变化。generator利用最新的discrimintor,苦苦思索,不断训练自己的权重,最终使discriminator将假图片鉴定为真图片。

换成印制假钞的例子,每次迭代发生了如下几件事:

  1. 假钞贩子根据最新造假技术,研发出一代假钞。
  2. 警察反复对比新型假钞和真币的区别,成功改良假钞鉴别方法,从而顺利鉴别出市面流通钞票的真伪。
  3. 假钞贩子生成假钞,马上被警察鉴别出来,痛定思痛,改良技术生成新的假钞。不成想,一上街又被警察识别了出来。日复一日,终于发明了新型假钞,当前的验钞技术已经无法成功检测出这种假钞。

然后通过每次迭代,discrimintor (警察的鉴定技术)和generator (假钞制作技术) 都越来越成熟...后来达到了动态平衡。

嗯,就这样,是不是挺简单的?

今天讲的是最原始的GAN网络,GAN发展到了如今已有许多变种,如将MLP结构换成CNN,Autoencoder,以及loss function的变化等等。我在github上找到一个超级全的用keras编写的各种花式GAN网络集合,有兴趣的小伙伴直接点击这里。本文的jupyter notebook代码请直接点击下面的小卡片~

https://nbviewer.jupyter.org/github/gaonanlee/Deep-Learning-Experiments/blob/master/Vanilla%20GAN_implementation.ipynb​nbviewer.jupyter.org

如有理解不到位之处,欢迎批评指教。

参考文献

  1. https://github.com/eriklindernoren/Keras-GAN/blob/master/cgan/cgan.py

2. https://medium.com/datadriveninvestor/generative-adversarial-network-gan-using-keras-ce1c05cfdfd3


我的其他回答:

哪些 Python 库让你相见恨晚?

python如何画出漂亮的地图?

时间序列数据如何插补缺失值?

机器学习中的因果关系: 从辛普森悖论(常见的统计学谬误)谈起

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值