plt生成固定的colormap_白话生成对抗网络GAN及代码实现

v2-9c8e01b0acd95923d8aa7987700d8440_1440w.jpg?source=172ae18b

本文主要是个简单的笔记,参考资料来自下面三部分

  1. Tutorial_HYLee_GAN
  2. Renu Khandelwal 的博客
  3. Jason 的博客

神经网络一览

各种神经网络(全连接前向网络、卷积神经网络、循环神经网络)的区别在于具有不同的输入/输出形式,比如可以是向量、矩阵或者是向量序列等。

v2-057dc13b951f864921885ae99e403d94_b.jpg

GAN的基本思想

GAN由生成器和判别器组成:

生成器的本质也是一个神经网络,或者说是一个函数

v2-ff7b89d79ea3d7a858d30cd3bf194a31_b.jpg

如果给定一个向量可以生成一张漫画图片,向量的每一个维度具有不同含义

v2-64bffe610c787d453ad9e2a0ae6e36ed_b.jpg

判别器的本质也是一个神经网络

v2-10959f8d2da50c24128c85c9554a0bd3_b.jpg

如果给定一张图片,判别器就会告诉你这是不是真实图片

v2-92fea0d79ce3c962972849b9d5316c71_b.jpg

所以GAN的训练本质就是训练两个神经网络

GAN的工作原理

生成器的目标是产生和训练数据相似的数据(以假乱真的图片),而判别器的目标是辨别真假。

生成器的输入通常为随机噪声,判别器有两个输入,一个来自训练数据中的真图片,一个来自生成器生成的假图片。

GAN的流程如下图所示

v2-c2794c1c83e939abb6984cdcb67b2c6e_b.jpg

每一次迭代过程中:

  1. 更新判别器的网络参数。即给定假图片以及假图片的标签(上图中的generated example)、真图片以及真图片的标签(上图中的real example),让判别器能够区别出真假图片,也就是训练一个尽可能准确的二分类器。
  2. 固定判别器网络参数, 更新生成器网络。即给定假图片以及假标签(让判别器以为假图片是真的),从而误差反向传播来更新生成器,使得生成器生成更加逼真的照片。

GAN训练的目标函数如下所示

v2-d851ce5335f6a541c5648352b9785b5f_b.jpg
  • 判别器想要最大化目标函数使得对于真实数据 D(x) 接近 1,对于假数据 D(G(z)) 接近 0
  • 生成器想要最小化目标函数使得 D(G(z)) 接近 1,也就是欺骗判别器让它认为假数据为真

GAN的实现

这里采用 MNIST 数据集作为实验数据,最后我们会看到生成器能够产生看起来像真的数字!

导入需要用到的库

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
import keras
from keras.layers import Dense, Dropout, Input
from keras.models import Model,Sequential
from keras.datasets import mnist
from tqdm import tqdm
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam

导入数据

def load_data():
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train = (x_train.astype(np.float32) - 127.5)/127.5

    # 将图片转为向量 x_train from (60000, 28, 28) to (60000, 784) 
    # 每一行 784 个元素
    x_train = x_train.reshape(60000, 784)
    return (x_train, y_train, x_test, y_test)
(X_train, y_train,X_test, y_test)=load_data()
print(X_train.shape)

定义优化器

def adam_optimizer():
    return Adam(lr=0.0002, beta_1=0.5)

这里要采用的生成对抗网络的结构如下图所示

v2-2699c06f0f5d8bffa7f786ed7d7955d9_b.jpg

定义生成器:输入是 100 维,经过三层隐藏层,输出 784 维的向量(造假的图片)

def create_generator():
    generator=Sequential()
    generator.add(Dense(units=256,input_dim=100))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(units=512))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(units=1024))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(units=784, activation='tanh'))

    generator.compile(loss='binary_crossentropy', optimizer=adam_optimizer())
    return generator
g=create_generator()
g.summary()

定义判别器:判别器的输入为真实图片或者由生成器造出来的假图片(784维),经过三层隐藏层,输出类别(1 维)

def create_discriminator():
    discriminator=Sequential()
    discriminator.add(Dense(units=1024,input_dim=784))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))


    discriminator.add(Dense(units=512))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))

    discriminator.add(Dense(units=256))
    discriminator.add(LeakyReLU(0.2))

    discriminator.add(Dense(units=1, activation='sigmoid'))

    discriminator.compile(loss='binary_crossentropy', optimizer=adam_optimizer())
    return discriminator
d =create_discriminator()
d.summary()

定义生成对抗网络

def create_gan(discriminator, generator):
    discriminator.trainable=False
    # 这是一个链式模型:输入经过生成器、判别器得到输出
    gan_input = Input(shape=(100,))
    x = generator(gan_input)
    gan_output= discriminator(x)
    gan= Model(inputs=gan_input, outputs=gan_output)
    gan.compile(loss='binary_crossentropy', optimizer='adam')
    return gan
gan = create_gan(d,g)
gan.summary()

定义画图函数来可视化图片的生成

def plot_generated_images(epoch, generator, examples=100, dim=(10,10), figsize=(10,10)):
    noise= np.random.normal(loc=0, scale=1, size=[examples, 100])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(100,28,28)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], interpolation='nearest')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('gan_generated_image %d.png' %epoch)

生成对抗网络的训练函数

def training(epochs=1, batch_size=128):

    #导入数据
    (X_train, y_train, X_test, y_test) = load_data()
    batch_count = X_train.shape[0] / batch_size

    # 定义生成器、判别器和GAN网络
    generator= create_generator()
    discriminator= create_discriminator()
    gan = create_gan(discriminator, generator)

    for e in range(1,epochs+1 ):
        print("Epoch %d" %e)
        for _ in tqdm(range(int(batch_count))):
            #产生噪声喂给生成器
            noise= np.random.normal(0,1, [batch_size, 100])

            # 产生假图片
            generated_images = generator.predict(noise)

            # 一组随机真图片
            image_batch =X_train[np.random.randint(low=0,high=X_train.shape[0],size=batch_size)]

            # 真假图片拼接 
            X= np.concatenate([image_batch, generated_images])

            # 生成数据和真实数据的标签
            y_dis=np.zeros(2*batch_size)
            y_dis[:batch_size]=0.9

            # 预训练,判别器区分真假
            discriminator.trainable=True
            discriminator.train_on_batch(X, y_dis)

            # 欺骗判别器 生成的图片为真的图片
            noise= np.random.normal(0,1, [batch_size, 100])
            y_gen = np.ones(batch_size)

            # GAN的训练过程中判别器的权重需要固定 
            discriminator.trainable=False

            # GAN的训练过程为交替“训练判别器”和“固定判别器权重训练链式模型”
            gan.train_on_batch(noise, y_gen)

        if e == 1 or e % 50 == 0:
            # 画图 看一下生成器能生成什么
            plot_generated_images(e, generator)
training(400,256)

经过训练后生成的图片

一个epoch后生成器还是个小学生

v2-6eb4600b0e8a16c095ef3b95cc06ae74_b.jpg

100个epoch后生成器已经有点样子了

v2-260ea6d134ffe04e66f97243bfcbb95b_b.jpg

400个epoch后生成器可以出师了

v2-7b79e0bc8ec29188fa29bfe1294a9cb4_b.jpg

是不是已经学得像模像样了,这样就能够利用噪声通过生成器来生成以假乱真的图片了。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值