GAN生成对抗网络----手写数据实现

该博客介绍了GAN(生成对抗网络)的工作原理,将其比喻为警察与罪犯的博弈过程。文章详细展示了如何构建和训练GAN的判别器和生成器模型,使用Tensorflow 2.4.1实现,并应用在MNIST手写数字数据集上。训练过程包括生成器和判别器的交替训练,以达到生成逼真图像的效果。最后,博客提供了训练过程中的结果展示和代码实现。
摘要由CSDN通过智能技术生成

GAN------ 以假乱真

GAN 的基本理念其实非常简单,其核心由两个目标互相冲突的神经网络组成,这两个网络会以越来越复杂的方法来“蒙骗”对方。这种情况可以理解为博弈论中的极大极小博弈树。

在这个过程中,我们想象有两类人:警察和罪犯。我们看看他们的之间互相冲突的目标:

  • 罪犯的目标:他的主要目标就是想出伪造货币的复杂方法,从而让警察无法区分假币和真币。
  • 警察的目标:他的主要目标就是想出辨别货币的复杂方法,这样就能够区分假币和真币。

随着这个过程不断继续,警察会想出越来越复杂的技术来鉴别假币,罪犯也会想出越来越复杂的技术来伪造货币。这就是 GAN 中“对抗过程”的基本理念。

GAN 充分利用“对抗过程”训练两个神经网络,这两个网络会互相博弈直至达到一种理想的平衡状态,我们这个例子中的警察和罪犯就相当于这两个神经网络。

其中一个神经网络叫做生成器网络 G(Z),它会使用输入随机噪声数据,生成和已有数据集非常接近的数据;

另一个神经网络叫鉴别器网络 D(X),它会以生成的数据作为输入,尝试鉴别出哪些是生成的数据,哪些是真实数据。鉴别器的核心是实现二元分类,输出的结果是输入数据来自真实数据集(和合成数据或虚假数据相对)的概率。

我们在前面所说的 GAN 最终能达到一种理想的平衡状态,是指生成器应该能模拟真实的数据,鉴别器输出的概率应该为 0.5, 即生成的数据和真实数据一致。也就是说,它不确定来自生成器的新数据是真实还是虚假,二者的概率相等。

训练流程

在这里插入图片描述

环境

  • tensorflow 2.4.1
  • numpy
  • matplotlib

数据集

mnist 手写数字

完整代码

'''
tensorflow 2.4.1
numpy
matplotlib
'''
# 设置GPU内存按需分配
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)
import numpy as np
import time
import cv2 as cv
from tensorflow.keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()


from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Activation,Flatten,Flatten, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, UpSampling2D
from tensorflow.keras.layers import LeakyReLU, Dropout
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.optimizers import Adam,RMSprop


import matplotlib.pyplot as plt

class ElapsedTimer(object):
    def __init__(self):
        self.start_time = time.time()
    def elapsed(self,sec):
        if sec < 60:
            return str(sec) + " sec"
        elif sec < (60 * 60):
            return str(sec / 60) + " min"
        else:
            return str(sec / (60 * 60)) + " hr"
    def elapsed_time(self):
        print("Elapsed: %s " % self.elapsed(time.time() - self.start_time) )

class DCGAN(object):
    def __init__(self, img_rows=28, img_cols=28, channel=1):

        self.img_rows = img_rows
        self.img_cols = img_cols
        self.channel = channel
        self.D = None   # discriminator
        self.G = None   # generator
        self.AM = None  # adversarial model
        self.DM = None  # discriminator model

    # (W−F+2P)/S+1
    # 判别模型
   # 14 * 14 * 1
   # 返回一个置信度
    def discriminator(self):
        if self.D:
            return self.D
        self.D = Sequential()
        depth = 64
        dropout = 0.4
        # In: 28 x 28 x 1, depth = 1
        # Out: 14 x 14 x 1, depth=64
        input_shape = (self.img_rows, self.img_cols, self.channel) # 14*14*1 的img
        """
        padding = “SAME”输入和输出大小关系:
            输出大小等于输入大小除以步长向上取整
            
        padding = “VALID”输入和输出大小关系:
            输出大小等于输入大小减去滤波器大小加上1,最后再除以步长
        """
        """
        64个5*5大小的内核,步长为2,🔠input:(14,14,1),padding=‘same’保证intput和output一样
        """
        self.D.add(Conv2D(64, 5, strides=2, input_shape=input_shape,padding='same'))# 14*14*64
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))

        self.D.add(Conv2D(128, 5, strides=2, padding='same')) # 7*7*128
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))

        self.D.add(Conv2D(256, 5, strides=2, padding='same')) # 4*4*256 向上取整
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))

        self.D.add(Conv2D(512, 5, strides=1, padding='same')) # 4*4*512
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))

        self.D.add(Conv2D(256, 5, strides=1, padding='same')) # 4*4*256
        self.D.add(LeakyReLU(alpha=0.2))
        self.D.add(Dropout(dropout))

        # Out: 1-dim probability
        self.D.add(Flatten())#扁平  4096=4*4*256
        self.D.add(Dense(1)) # 输出 1个
        self.D.add(Activation('sigmoid')) # 二分类
        self.D.summary()
        return self.D

    # 生成模型
    # 全连接  7*7*256
    # 返回一张图 28*28*1
    def generator(self):
        if self.G:
            return self.G
        self.G = Sequential()
        dropout = 0.4
        depth = 64+64+64+64
        dim = 7
        # In: 100
        # Out: dim x dim x depth
        self.G.add(Dense(dim*dim*depth, input_dim=100))#全连接  7*7*256 的大小
        """
        参数作用于mean和variance的计算上, 这里保留了历史batch里的mean和variance值,即 moving_mean和moving_variance, 
        借鉴优化算法里的momentum算法将历史batch里的mean和variance的作用延续到当前batch. 一般momentum的值为0.9 , 0.99等. 
        多个batch后, 即多个0.9连乘后,最早的batch的影响会变弱.
        """
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))
        self.G.add(Reshape((dim, dim, depth))) # 7*7*256
        self.G.add(Dropout(dropout))

        # In: dim x dim x depth
        # Out: 2*dim x 2*dim x depth/2
        self.G.add(UpSampling2D()) # 翻倍 14*14*256
        """
        输入图像通过卷积操作提取特征后,输出的尺寸常会变小,而有时我们需要将图像恢复到原来的尺寸以便进行进一步的计算(比如:图像的语义分割),
        那么我们需要实现图像由小分辨率到大分辨率的映射的操作,叫做上采样(Upsample)。
        """
        self.G.add(Conv2DTranspose(int(depth/2), 5, padding='same')) # 反卷积 14*14*128
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))

        self.G.add(UpSampling2D())# 28*28*128
        self.G.add(Conv2DTranspose(int(depth/4), 5, padding='same')) # 28*28*64
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))

        self.G.add(Conv2DTranspose(int(depth/8), 5, padding='same')) # 28*28*32
        self.G.add(BatchNormalization(momentum=0.9))
        self.G.add(Activation('relu'))

        # Out: 28 x 28 x 1 grayscale image [0.0,1.0] per pix
        self.G.add(Conv2DTranspose(1, 5, padding='same')) # 28*28*1 输出一张特征图(就是生成的图像)
        self.G.add(Activation('sigmoid'))
        self.G.summary()
        return self.G

    def discriminator_model(self):
        if self.DM:
            return self.DM
        optimizer = RMSprop(lr=0.0002, decay=6e-8)
        self.DM = Sequential()
        self.DM.add(self.discriminator())
        # print("DM")
        # self.DM.summary()
        self.DM.compile(loss='binary_crossentropy', optimizer=optimizer,\
            metrics=['accuracy'])
        return self.DM

    def adversarial_model(self):
        if self.AM:
            return self.AM
        optimizer =RMSprop(lr=0.0001, decay=3e-8)
        self.AM = Sequential()
        self.AM.add(self.generator())
        self.AM.add(self.discriminator())
        # print('AM')
        # self.AM.summary()
        self.AM.compile(loss='binary_crossentropy', optimizer=optimizer,\
            metrics=['accuracy'])
        return self.AM

class MNIST_DCGAN(object):
    def __init__(self):
        self.img_rows = 28
        self.img_cols = 28
        self.channel = 1
        (X_train, y_train), (X_test, y_test) = mnist.load_data()
        X_train = X_train / 255.0
        self.x_train = X_train.reshape(-1, 28, 28, 1).astype(np.float32)


        self.DCGAN = DCGAN()
        self.discriminator =  self.DCGAN.discriminator_model()
        self.adversarial = self.DCGAN.adversarial_model()
        self.generator = self.DCGAN.generator()
    def train(self, train_steps=2000, batch_size=256, save_interval=0):
        noise_input = None
        if save_interval>0:
            noise_input = np.random.uniform(-1.0, 1.0, size=[16, 100])
        for i in range(train_steps):


            """"
                第一轮,由于是没有权重,随机噪声
                再后我们对判别器进行训练之后,loss更新,生成器网络权重更新
            """
            images_train = self.x_train[np.random.randint(0,self.x_train.shape[0], size=batch_size), :, :, :] # 随机选取128张图像 [128,28,28,1]
            noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100]) #128,100 的随机【-1,1】之间的数
            images_fake = self.generator.predict(noise)   # 生成模型训练,图   [128,28,28,1]

            """
            图像保存 每5轮保存一次生成器所生成的image
            """
            if i%5==0:
                plt.figure(figsize=(24, 24))
                for j in range(16):
                    plt.subplot(4, 4, j + 1)
                    image = images_fake[j, :, :, :]
                    image = np.reshape(image, [28,28])
                    plt.imshow(image, cmap='gray')
                    plt.axis('off')
                    plt.tight_layout()
                filename = './g/img_{}'.format(i)
                # plt.savefig(filename)
                plt.close('all')


            """"
            在鉴别器的训练过程中,它显示为真实图像,并用于计算鉴别器损耗。
             它对来自生成器的真实和伪造图像进行分类,如果对任何图像进行了不正确分类,则鉴别器损失将对鉴别器进行惩罚。 
             通过反向传播,鉴别器更新其权重
             
             类似地,为生成器提供了噪声输入以生成伪图像。 这些图像被提供给鉴别器,并且发生器损失惩罚了发生器以产生鉴别器网络分类为伪造的样本。
              权重通过从鉴别器到生成器的反向传播进行更新
            """
            x = np.concatenate((images_train, images_fake)) #256*28*28*1  维度相加 数组拼接(将训练图片与生成的向量拼接), axis=0 按照行拼接。axis=1 按照列拼接,默认0
            print('4',x.shape)
            y = np.ones([2*batch_size, 1]) # 生成(256,1)的全是1的数组
            y[batch_size:, :] = 0 #  256*1   第128-256行的所有列全为0

            d_loss = self.discriminator.train_on_batch(x, y)#鉴别

            """
            核心
            """
            y = np.ones([batch_size, 1]) # 128*1
            noise = np.random.uniform(-1.0, 1.0, size=[batch_size, 100]) #128*100
            a_loss = self.adversarial.train_on_batch(noise, y)


            log_mesg = "%d: [D loss: %f, acc: %f]" % (i, d_loss[0], d_loss[1])
            log_mesg = "%s  [A loss: %f, acc: %f]" % (log_mesg, a_loss[0], a_loss[1])
            print(log_mesg)


            if save_interval>0:
                if (i+1)%save_interval==0:
                    self.plot_images(save2file=True, samples=noise_input.shape[0],\
                        noise=noise_input, step=(i+1))

    def plot_images(self, save2file=False, fake=True, samples=16, noise=None, step=0):
        filename = 'mnist.png'
        if fake:
            if noise is None:
                noise = np.random.uniform(-1.0, 1.0, size=[samples, 100])
            else:
                filename = "mnist_%d.png" % step
            images = self.generator.predict(noise)
        else:
            i = np.random.randint(0, self.x_train.shape[0], samples)
            images = self.x_train[i, :, :, :]

        plt.figure(figsize=(10,10))
        for i in range(images.shape[0]):
            plt.subplot(4, 4, i+1)
            image = images[i, :, :, :]
            image = np.reshape(image, [self.img_rows, self.img_cols])
            plt.imshow(image, cmap='gray')
            plt.axis('off')
        plt.tight_layout()

        # if save2file:
        #     plt.savefig(filename)
        #     plt.close('all')
        # else:
        #     plt.show()

if __name__ == '__main__':
    mnist_dcgan = MNIST_DCGAN()
    timer = ElapsedTimer()
    mnist_dcgan.train(train_steps=10000, batch_size=128, save_interval=1000)
    timer.elapsed_time()
    mnist_dcgan.plot_images(fake=True)
    mnist_dcgan.plot_images(fake=False, save2file=True)


结果展示

请添加图片描述
【参考文献】
https://www.cnblogs.com/dereen/p/gan.html
https://zhuanlan.zhihu.com/p/43047326
https://www.zhihu.com/question/306213462

  • 15
    点赞
  • 61
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值