通过GAN生成手写数据集

生成对抗网络(GAN)的主要设计目的是生成逼真的数据样本,而不是用于分类任务。然而,有一些扩展和变体可以将 GAN 与分类任务结合起来。

以下是一些将 GAN 与分类任务结合的方法:

  1. 生成对抗网络(GAN)和条件生成对抗网络(cGAN):cGAN 是 GAN 的一种变体,它引入了条件信息作为生成器和判别器的输入。这样一来,cGAN 可以按照给定的条件生成相关联的样本,这个条件可以是类别标签。因此,cGAN 可以在生成过程中同时控制生成的样本的类别。

  2. 生成对抗网络(GAN)和生成分类模型(GCM):GCM 是将 GAN 的生成器用作分类器的一种方法。在这种情况下,生成器训练用于生成逼真的样本,但同时也可以用作分类器,对生成的样本进行分类。

  3. GAN 可以用于数据增强:生成器可以用于生成合成的训练样本,从而扩充训练数据集。这种扩充可以改善分类器的泛化能力。

  4. 生成对抗网络(GAN)用于特征生成:GAN 可以用于生成具有特定属性或特征的样本,这些样本可以用于训练分类器,以提高分类性能。

尽管可以将 GAN 与分类任务结合使用,但一般来说,GAN 更适合于生成任务,而其他模型(如卷积神经网络,支持向量机等)更适合于分类任务。

学习测试代码

"""
# -*- coding: utf-8 -*-
# @Time    : 2023/10/17 8:43
# @Author  : 王摇摆
# @FileName: gan.py
# @Software: PyCharm
# @Blog    :https://blog.csdn.net/weixin_44943389?type=blog
"""

import matplotlib.pyplot as plt
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Flatten, Input
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam


# 定义生成器
def build_generator(latent_dim):
    model = Sequential([
        Dense(128, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(256),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(512),
        LeakyReLU(alpha=0.2),
        BatchNormalization(momentum=0.8),
        Dense(784, activation='sigmoid'),
        Reshape((28, 28, 1))
    ])
    return model


# 定义判别器
def build_discriminator():
    model = Sequential([
        Flatten(input_shape=(28, 28, 1)),
        Dense(512),
        LeakyReLU(alpha=0.2),
        Dense(256),
        LeakyReLU(alpha=0.2),
        Dense(1, activation='sigmoid')
    ])
    return model


# 定义GAN
def build_gan(generator, discriminator):
    z = Input(shape=(latent_dim,))
    img = generator(z)
    validity = discriminator(img)
    model = Model(z, validity)
    return model


# 设置随机种子
np.random.seed(1000)
latent_dim = 100

# 初始化并编译GAN模型
generator = build_generator(latent_dim)
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
                      metrics=['accuracy'])

discriminator.trainable = False
gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))

# 加载MNIST数据集
(X_train, _), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=-1)


# 定义训练函数
def train_gan(epochs=1, batch_size=128):
    batch_count = X_train.shape[0] // batch_size

    for e in range(epochs):
        for _ in range(batch_count):
            noise = np.random.normal(0, 1, size=[batch_size, latent_dim])
            generated_images = generator.predict(noise)
            image_batch = X_train[np.random.randint(0, X_train.shape[0], size=batch_size)]

            X = np.concatenate([image_batch, generated_images])
            y_dis = np.zeros(2 * batch_size)
            y_dis[:batch_size] = 0.9

            discriminator.trainable = True
            d_loss = discriminator.train_on_batch(X, y_dis)

            noise = np.random.normal(0, 1, size=[batch_size, latent_dim])
            y_gen = np.ones(batch_size)
            discriminator.trainable = False
            g_loss = gan.train_on_batch(noise, y_gen)

        print(f"Epoch {e} - Discriminator Loss: {d_loss[0]}, Generator Loss: {g_loss}")

        if e % 10 == 0:
            plot_generated_images(e, generator)


# 定义生成图像的函数
def plot_generated_images(epoch, generator, examples=10, dim=(1, 10), figsize=(10, 1)):
    noise = np.random.normal(0, 1, size=[examples, latent_dim])
    generated_images = generator.predict(noise)
    generated_images = generated_images.reshape(examples, 28, 28)

    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i + 1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='gray_r')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(f'gan_generated_image_epoch_{epoch}.png')


# 训练GAN模型
train_gan(epochs=200, batch_size=128)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
你可以使用生成对抗网络GAN生成数据集,并对其进行可视化。下面是一个简单的示例代码,使用GAN生成手写数字数据集(MNIST)并进行可视化: ```python import numpy as np import matplotlib.pyplot as plt from tensorflow.keras.datasets import mnist from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization # 加载MNIST数据集 (train_images, _), (_, _) = mnist.load_data() # 数据预处理 train_images = train_images.astype('float32') train_images = (train_images - 127.5) / 127.5 # 将像素值从[0, 255]缩放到[-1, 1] train_images = train_images.reshape(train_images.shape[0], 784) # 设置生成器模型 generator = Sequential() generator.add(Dense(256, input_shape=(100,))) generator.add(LeakyReLU(alpha=0.2)) generator.add(BatchNormalization()) generator.add(Dense(512)) generator.add(LeakyReLU(alpha=0.2)) generator.add(BatchNormalization()) generator.add(Dense(784, activation='tanh')) # 生成样本 noise = np.random.normal(0, 1, size=[10, 100]) # 生成10个随机噪声样本 generated_images = generator.predict(noise) # 可视化生成的样本 fig, axes = plt.subplots(1, 10, figsize=(10, 1)) for i, ax in enumerate(axes): ax.imshow(generated_images[i].reshape(28, 28), cmap='gray') ax.axis('off') plt.show() ``` 这段代码使用Keras库构建了一个简单的生成器模型,然后使用随机噪声作为输入,生成10个手写数字样本,并将其可视化显示出来。你可以根据需要调整生成的样本数量和模型架构。记得在运行代码之前安装必要的库和数据集

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

王摇摆

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值