生成对抗网络GAN---生成mnist手写数字图像示例(附代码)

Ian J. Goodfellow等人于2014年在论文Generative Adversarial Nets中提出了一个通过对抗过程估计生成模型的新框架。框架中同时训练两个模型:一个生成模型(generative model)G,用来捕获数据分布;一个判别模型(discriminative model)D,用来估计样本来自于训练数据的概率。G的训练过程是将D错误的概率最大化。可以证明在任意函数G和D的空间中,存在唯一的解决方案,使得G重现训练数据分布,而D=0.5。

生成对抗网络(GAN,Generative Adversarial Networks)的基本原理很简单:假设有两个网络,生成网络G和判别网络D。生成网络G接受一个随机的噪声z并生成图片,记为G(z);判别网络D的作用是判别一张图片x是否真实,对于输入x,D(x)是x为真实图片的概率。在训练过程中, 生成器努力让生成的图片更加真实从而使得判别器无法辨别图像的真假,而D的目标就是尽量把分辨出真实图片和生成网络G产出的图片,这个过程就类似于二人博弈,G和D构成了一个动态的“博弈过程”。随着时间的推移,生成器和判别器在不断地进行对抗,最终两个网络达到一个动态平衡:生成器生成的图像G(z)接近于真实图像分布,而判别器识别不出真假图像,即D(G(z))=0.5。最后,我们就可以得到一个生成网络G,用来生成图片。

对于GAN更加直观的理解:生成模型可以被看做是一个伪造团队,试图生产假币并且在不被发现的情况下使用, 而判别模型则类似于警察,尝试检查是否为假币。伪造团队的目的是生产出警察识别不出的假币,而警察则是想更加精确地识别出假币,因此在这个游戏中,两个团队因为各自目的而不断改进它们的方法直到伪造团队生产的假币警察分辨不出来。

 上面讲述生成对抗网络的基本原理, 为了能够更深此理解GAN,下面我们使用GAN来生成MNIST数据集。

import tensorflow as tf
import numpy as np
import os
from tensorflow.examples.tutorials.mnist import input_data
from matplotlib import pyplot as plt

BATCH_SIZE = 64
UNITS_SIZE = 128
LEARNING_RATE = 0.001
EPOCH = 300
SMOOTH = 0.1

mnist = input_data.read_data_sets('/mnist_data/', one_hot=True)


# 生成模型
def generatorModel(noise_img, units_size, out_size, alpha=0.01):
    with tf.variable_scope('generator'):
        FC = tf.layers.dense(noise_img, units_size)
        reLu = tf.nn.leaky_relu(FC, alpha)
        drop = tf.layers.dropout(reLu, rate=0.2)
        logits = tf.layers.dense(drop, out_size)
        outputs = tf.tanh(logits)
        return logits, outputs

# 判别模型
def discriminatorModel(images, units_size, alpha=0.01, reuse=False):
    with tf.variable_scope('discriminator', reuse=reuse):
        FC = tf.layers.dense(images, units_size)
        reLu = tf.nn.leaky_relu(FC, alpha)
        logits = tf.layers.dense(reLu, 1)
        outputs = tf.sigmoid(logits)
        return logits, outputs

# 损失函数
"""
判别器的目的是:
1. 对于真实图片,D要为其打上标签1
2. 对于生成图片,D要为其打上标签0
生成器的目的是:对于生成的图片,G希望D打上标签1
"""
def loss_function(real_logits, fake_logits, smooth):
    # 生成器希望判别器判别出来的标签为1; tf.ones_like()创建一个将所有元素都设置为1的张量
    G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits,
                                                                    labels=tf.ones_like(fake_logits)*(1-smooth)))
    # 判别器识别生成器产出的图片,希望识别出来的标签为0
    fake_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_logits,
                                                                       labels=tf.zeros_like(fake_logits)))
    # 判别器判别真实图片,希望判别出来的标签为1
    real_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=real_logits,
                                                                       labels=tf.ones_like(real_logits)*(1-smooth)))
    # 判别器总loss
    D_loss = tf.add(fake_loss, real_loss)
    return G_loss, fake_loss, real_loss, D_loss

# 优化器
def optimizer(G_loss, D_loss, learning_rate):
    train_var = tf.trainable_variables()
    G_var = [var for var in train_var if var.name.startswith('generator')]
    D_var = [var for var in train_var if var.name.startswith('discriminator')]
    # 因为GAN中一共训练了两个网络,所以分别对G和D进行优化
    G_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(G_loss, var_list=G_var)
    D_optimizer = tf.train.AdamOptimizer(learning_rate).minimize(D_loss, var_list=D_var)
    return G_optimizer, D_optimizer


# 训练
def train(mnist):
    image_size = mnist.train.images[0].shape[0]
    real_images = tf.placeholder(tf.float32, [None, image_size])
    fake_images = tf.placeholder(tf.float32, [None, image_size])

    #调用生成模型生成图像G_output
    G_logits, G_output = generatorModel(fake_images, UNITS_SIZE, image_size)
    # D对真实图像的判别
    real_logits, real_output = discriminatorModel(real_images, UNITS_SIZE)
    # D对G生成图像的判别
    fake_logits, fake_output = discriminatorModel(G_output, UNITS_SIZE, reuse=True)
    # 计算损失函数
    G_loss, real_loss, fake_loss, D_loss = loss_function(real_logits, fake_logits, SMOOTH)
    # 优化
    G_optimizer, D_optimizer = optimizer(G_loss, D_loss, LEARNING_RATE)

    saver = tf.train.Saver()
    step = 0
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        for epoch in range(EPOCH):
            for batch_i in range(mnist.train.num_examples // BATCH_SIZE):
                batch_image, _ = mnist.train.next_batch(BATCH_SIZE)
                # 对图像像素进行scale,tanh的输出结果为(-1,1)
                batch_image = batch_image * 2 -1
                # 生成模型的输入噪声
                noise_image = np.random.uniform(-1, 1, size=(BATCH_SIZE, image_size))
                #
                session.run(G_optimizer, feed_dict={fake_images:noise_image})
                session.run(D_optimizer, feed_dict={real_images: batch_image, fake_images: noise_image})
                step = step + 1
            # 判别器D的损失
            loss_D = session.run(D_loss, feed_dict={real_images: batch_image, fake_images:noise_image})
            # D对真实图片
            loss_real =session.run(real_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
            # D对生成图片
            loss_fake = session.run(fake_loss, feed_dict={real_images: batch_image, fake_images: noise_image})
            # 生成模型G的损失
            loss_G = session.run(G_loss, feed_dict={fake_images: noise_image})
            print('epoch:', epoch, 'loss_D:', loss_D, ' loss_real', loss_real, ' loss_fake', loss_fake, ' loss_G', loss_G)
            model_path = os.getcwd() + os.sep + "mnist.model"
            saver.save(session, model_path, global_step=step)

def main(argv=None):
    train(mnist)

if __name__ == '__main__':
    tf.app.run()

上述是训练模型,下面是测试模型,依据训练模型训练的参数。generatorImage函数生成手写字体图片, 在这里显示了25张图片。 生成图像如下图1所示,还能够大略猜出生成的图片中的数字。

import tensorflow as tf
import numpy as np
from matplotlib import pyplot as plt
import pickle
import mnist_GAN

UNITS_SIZE = mnist_GAN.UNITS_SIZE

def generatorImage(image_size):
    sample_images = tf.placeholder(tf.float32, [None, image_size])
    G_logits, G_output = mnist_GAN.generatorModel(sample_images, UNITS_SIZE, image_size)
    saver = tf.train.Saver()
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        saver.restore(session, tf.train.latest_checkpoint('.'))
        sample_noise = np.random.uniform(-1, 1, size=(25, image_size))
        samples = session.run(G_output, feed_dict={sample_images:sample_noise})
    with open('samples.pkl', 'wb') as f:
        pickle.dump(samples, f)

def show():
    with open('samples.pkl', 'rb') as f:
        samples = pickle.load(f)
    fig, axes = plt.subplots(figsize=(7, 7), nrows=5, ncols=5, sharey=True, sharex=True)
    for ax, image in zip(axes.flatten(), samples):
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
        ax.imshow(image.reshape((28, 28)), cmap='Greys_r')
    plt.show()

def main(argv=None):
    image_size = mnist_GAN.mnist.train.images[0].shape[0]
    generatorImage(image_size)
    show()

if __name__ == '__main__':
    tf.app.run()
图1. 生成图片展示

上述基于MNIST数据集构造了一个简单的GAN模型,对于生成模型和判别模型,仅仅使用了简单的神经网络,对于图像的处理,卷积神经网络更胜一筹,如果将生成模型和判别模型改为深度卷积网络,那么生成更加清晰的图片。 而且目前也有各种GAN变体,后续慢慢整理。

参考博客:基于GAN的mnist训练集图片生成神经网络实现_gan训练集_lpty的博客-CSDN博客

  • 27
    点赞
  • 183
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 13
    评论
以下是一个基于GANmnist手写数字生成的PyTorch代码示例: ```python import torch import torch.nn as nn from torchvision.datasets import MNIST from torchvision.transforms import ToTensor from torch.utils.data import DataLoader # 定义生成器 class Generator(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(Generator, self).__init__() self.fc1 = nn.Linear(input_size, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, output_size) self.relu = nn.ReLU() def forward(self, x): x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) x = self.fc3(x) return x # 定义判别器 class Discriminator(nn.Module): def __init__(self, input_size, hidden_size): super(Discriminator, self).__init__() self.fc1 = nn.Linear(input_size, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, 1) self.sigmoid = nn.Sigmoid() def forward(self, x): x = self.sigmoid(self.fc1(x)) x = self.sigmoid(self.fc2(x)) x = self.sigmoid(self.fc3(x)) return x # 定义超参数 input_size = 100 hidden_size = 256 output_size = 784 batch_size = 128 num_epochs = 200 # 加载MNIST数据集 train_dataset = MNIST(root='./data', train=True, transform=ToTensor(), download=True) train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True) # 初始化网络 G = Generator(input_size, hidden_size, output_size) D = Discriminator(output_size, hidden_size) # 定义损失函数和优化器 criterion = nn.BCELoss() lr = 0.0002 G_optimizer = torch.optim.Adam(G.parameters(), lr=lr) D_optimizer = torch.optim.Adam(D.parameters(), lr=lr) # 定义真实和假的标签 real_label = torch.ones(batch_size, 1) fake_label = torch.zeros(batch_size, 1) # 训练网络 for epoch in range(num_epochs): for i, (images, _) in enumerate(train_loader): # 定义真实和假的图像 real_images = images.view(batch_size, -1) z = torch.randn(batch_size, input_size) fake_images = G(z) # 训练判别器 D_real_loss = criterion(D(real_images), real_label) D_fake_loss = criterion(D(fake_images.detach()), fake_label) D_loss = D_real_loss + D_fake_loss D_optimizer.zero_grad() D_loss.backward() D_optimizer.step() # 训练生成器 G_loss = criterion(D(fake_images), real_label) G_optimizer.zero_grad() G_loss.backward() G_optimizer.step() # 打印损失 if (i+1) % 100 == 0: print('Epoch [{}/{}], Step [{}/{}], D Loss: {:.4f}, G Loss: {:.4f}' .format(epoch+1, num_epochs, i+1, len(train_loader), D_loss.item(), G_loss.item())) # 保存模型 torch.save(G.state_dict(), 'generator.pth') ``` 在训练完成后,可以使用生成器来生成新的手写数字图像,例如: ```python import matplotlib.pyplot as plt import numpy as np # 加载生成器 G = Generator(input_size, hidden_size, output_size) G.load_state_dict(torch.load('generator.pth')) # 生成图像 z = torch.randn(1, input_size) fake_image = G(z).detach().numpy() fake_image = np.reshape(fake_image, (28, 28)) # 显示图像 plt.imshow(fake_image, cmap='gray') plt.show() ``` 这样就可以生成一个随机的手写数字图像了。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

马鹤宁

谢谢

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值