GAN入门基础

介绍

GAN一直是深度学习比较火的一种模型,在各个领域都有应用,不论是在CV、NLP、还是AR、VR,GAN的加入都让他们更加立体生动。最近元宇宙这个概念被炒的比较火热,这其中GAN就发挥了巨大的作用。来看几组GAN的例(妹)子。
在这里插入图片描述在这里插入图片描述
是心动呀~当然是对GAN的,下面让我们来看一下什么是GAN吧。

什么是GAN

生成式对抗网络(GAN, Generative Adversarial Networks )是一种深度学习模型,是近年来在复杂分布上无监督学习最具前景的方法之一。

在GAN模型中,一般存在两个模块:生成模型(Generative Model)和判别模型(Discriminative Model),两者的相互博弈和学习使得他们共同进步并最终产生出真假难分的目标。
生成器就好比一个制造假币的骗子,而判别器就像是检验假币的警察,在最开始的时候这个骗子制造的假币的质量是很差的,一眼就能被警察识别,这个时候骗子就需要提升自己的造假技术去骗过警察,随着骗子制造出的假币越来越像真的,警察就很难辨别出假币的真假,这时候警察也需要通过学习提升自己的鉴别能力,就这样警察和骗子不断的欺骗和学习,最终生成真假难辨的纸币。

GAN的网络框架
在这里插入图片描述
以生成图片为例子,Generator是一个生成网络,它接收一个随机噪声noise,通过这个噪声生成图片,记作Generator Data,Discriminator是一个判别网络,判别一张图片是不是真实的,它的输出是一个概率值,如果为1,就代表100%是真实的图片,如果为0就代表100%是假的图片。

生成器实现代码

以手写体数据集MNIST为例,生成网络的输入是一行正态分布的随机数,因此它的输入是一个长度为N的一维向量,输出是一个(28,28,1)维的图片。下面的代码基于Keras框架。

def build_generator(self):
    # --------------------------------- #
    #   生成器,输入一串随机数字
    # --------------------------------- #
    model = Sequential()

    model.add(Dense(256, input_dim=self.latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))

    model.add(Dense(np.prod(self.img_shape), activation='tanh'))
    model.add(Reshape(self.img_shape))

    noise = Input(shape=(self.latent_dim,))
    img = model(noise)

    return Model(noise, img)

判别器实现代码

判别器的目的是根据输入图片判断出真假。因此它的是一个(28,28,1)维的图片,输出是0到1之间的数,1代表这个图片是真的,0代表这个图片是假的。

def build_discriminator(self):
    # ----------------------------------- #
    #   评价器,对输入进来的图片进行评价`在这里插入代码片`
    # ----------------------------------- #
    model = Sequential()
    # 输入一张图片
    model.add(Flatten(input_shape=self.img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    # 判断真伪
    model.add(Dense(1, activation='sigmoid'))

    img = Input(shape=self.img_shape)
    validity = model(img)

    return Model(img, validity)

GAN模型的优化训练

在训练过程中,生成器的目标就是要尽可能的生成真实的图片去欺骗判别器。而判别器的目标就是尽可能把生成器生成的真实图片区分开来,这样一来,生成器和判别器构成了一个动态的博弈过程。为了深入的理解这个博弈过程,我们先来了解一下什么是纳什均衡。

纳什均衡是指博弈中出现这样的局面,对于每个参与者,只要其他参与者不改变决策,他自己就不能改变策略。对应到GAN上,就是生成器制造出了和真实数据一模一样的数据,判别器再也判别不出来结果,准确率为50%,越等于乱猜,这是双方网络都得到了利益最大化,不改变自己的策略,也就是不再更新自己的网络权重。

GAN模型的目标函数如下:
在这里插入图片描述
这样对抗训练之后,效果可能有几个过程,原论文画出的图如下:
在这里插入图片描述
黑色的线表示数据x的实际分布,绿色的线表示数据的生成分布,蓝色的线表示生成的数据对应在判别器中的分布效果
原论文的整体算法:
在这里插入图片描述

全部代码

rom __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys
import os
import numpy as np

class GAN():
    def __init__(self):
        # --------------------------------- #
        #   行28,列28,也就是mnist的shape
        # --------------------------------- #
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        # 28,28,1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100
        # adam优化器
        optimizer = Adam(0.0002, 0.5)

        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        self.generator = self.build_generator()
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        # 在训练generate的时候不训练discriminator
        self.discriminator.trainable = False
        # 对生成的假图片进行预测
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)


    def build_generator(self):
        # --------------------------------- #
        #   生成器,输入一串随机数字
        # --------------------------------- #
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        # ----------------------------------- #
        #   评价器,对输入进来的图片进行评价
        # ----------------------------------- #
        model = Sequential()
        # 输入一张图片
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        # 判断真伪
        model.add(Dense(1, activation='sigmoid'))

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        # 获得数据
        (X_train, _), (_, _) = mnist.load_data()

        # 进行标准化
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # 创建标签
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # --------------------------- #
            #   随机选取batch_size个图片
            #   对discriminator进行训练
            # --------------------------- #
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            gen_imgs = self.generator.predict(noise)

            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # --------------------------- #
            #  训练generator
            # --------------------------- #
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid)
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):

        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    if not os.path.exists("./images"):
        os.makedirs("./images")
    gan = GAN()
    gan.train(epochs=30000, batch_size=256, sample_interval=200)

最后GAN生成的结果是这样的

在这里插入图片描述

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值