昇思25天学习打卡营第24天|GAN图像生成

使用背景

        生成式对抗网络(GANs)是近年来复杂分布上无监督学习最具前景的方法之一。它通过生成器和判别器的博弈学习,实现了从噪声中生成逼真图像的能力。GANs广泛应用于图像生成、图像修复、风格转换、超分辨率等领域。

原理

        GANs 由两个主要组件组成:

  • 生成器(Generator):生成器的任务是从随机噪声中生成逼真的图像。
  • 判别器(Discriminator):判别器的任务是区分真实图像和生成图像。

        GAN的训练是一个零和博弈过程,其中生成器试图欺骗判别器,使其认为生成的图像是真实的,而判别器则努力提高识别生成图像的能力。通过反复迭代训练,生成器生成的图像会越来越逼真。

实现代码

数据加载与预处理

        使用MNIST手写数字数据集进行训练,并进行数据加载和预处理。

import numpy as np
import mindspore.dataset as ds
import matplotlib.pyplot as plt

batch_size = 64
latent_size = 100

train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')

def data_load(dataset):
    dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, python_multiprocessing=False, num_samples=10000)
    mnist_ds = dataset1.map(
        operations=lambda x: (x.astype("float32"), np.random.normal(size=latent_size).astype("float32")),
        output_columns=["image", "latent_code"])
    mnist_ds = mnist_ds.project(["image", "latent_code"])
    mnist_ds = mnist_ds.batch(batch_size, True)
    return mnist_ds

mnist_ds = data_load(train_dataset)
iter_size = mnist_ds.get_dataset_size()
print('Iter size: %d' % iter_size)
生成器

        定义生成器模型,将随机噪声映射到图像空间。

from mindspore import nn
import mindspore.ops as ops

img_size = 28

class Generator(nn.Cell):
    def __init__(self, latent_size):
        super(Generator, self).__init__()
        self.model = nn.SequentialCell([
            nn.Dense(latent_size, 128),
            nn.ReLU(),
            nn.Dense(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dense(256, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dense(512, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Dense(1024, img_size * img_size),
            nn.Tanh()
        ])

    def construct(self, x):
        img = self.model(x)
        return ops.reshape(img, (-1, 1, img_size, img_size))

net_g = Generator(latent_size)
net_g.update_parameters_name('generator')
判别器

        定义判别器模型,区分输入图像是真实的还是生成的。

class Discriminator(nn.Cell):
    def __init__(self):
        super().__init__()
        self.model = nn.SequentialCell([
            nn.Dense(img_size * img_size, 512),
            nn.LeakyReLU(),
            nn.Dense(512, 256),
            nn.LeakyReLU(),
            nn.Dense(256, 1),
            nn.Sigmoid()
        ])

    def construct(self, x):
        x_flat = ops.reshape(x, (-1, img_size * img_size))
        return self.model(x_flat)

net_d = Discriminator()
net_d.update_parameters_name('discriminator')
损失函数和优化器

        定义GAN的损失函数和优化器。

from mindspore import nn

lr = 0.0002

adversarial_loss = nn.BCELoss(reduction='mean')
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')
模型训练

        定义训练过程,包括训练生成器和判别器。

import os
import time
import mindspore as ms
from mindspore import Tensor, save_checkpoint

total_epoch = 12
batch_size = 64

checkpoints_path = "./result/checkpoints"
image_path = "./result/images"

def generator_forward(test_noises):
    fake_data = net_g(test_noises)
    fake_out = net_d(fake_data)
    loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))
    return loss_g

def discriminator_forward(real_data, test_noises):
    fake_data = net_g(test_noises)
    fake_out = net_d(fake_data)
    real_out = net_d(real_data)
    real_loss = adversarial_loss(real_out, ops.ones_like(real_out))
    fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))
    loss_d = real_loss + fake_loss
    return loss_d

grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())

def train_step(real_data, latent_code):
    loss_d, grads_d = grad_d(real_data, latent_code)
    optimizer_d(grads_d)
    loss_g, grads_g = grad_g(latent_code)
    optimizer_g(grads_g)
    return loss_d, loss_g

def save_imgs(gen_imgs1, idx):
    for i3 in range(gen_imgs1.shape[0]):
        plt.subplot(5, 5, i3 + 1)
        plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")
        plt.axis("off")
    plt.savefig(image_path + "/test_{}.png".format(idx))

os.makedirs(checkpoints_path, exist_ok=True)
os.makedirs(image_path, exist_ok=True)

net_g.set_train()
net_d.set_train()

losses_g, losses_d = [], []

for epoch in range(total_epoch):
    start = time.time()
    for (iter, data) in enumerate(mnist_ds):
        start1 = time.time()
        image, latent_code = data
        image = (image - 127.5) / 127.5
        image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])
        d_loss, g_loss = train_step(image, latent_code)
        end1 = time.time()
        if iter % 10 == 10:
            print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "
                  f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "
                  f"loss_d:{d_loss.asnumpy():>4f} , "
                  f"loss_g:{g_loss.asnumpy():>4f} , "
                  f"time:{(end1 - start1):>3f}s, "
                  f"lr:{lr:>6f}")
    end = time.time()
    print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))
    losses_d.append(d_loss.asnumpy())
    losses_g.append(g_loss.asnumpy())
    gen_imgs = net_g(test_noise)
    save_imgs(gen_imgs.asnumpy(), epoch)
    if epoch % 1 == 0:
        save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))
        save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))

模型推理

        加载训练好的生成器模型并生成图像。

import mindspore as ms
import matplotlib.pyplot as plt
import numpy as np

test_ckpt = './result/checkpoints/Generator11.ckpt'
parameter = ms.load_checkpoint(test_ckpt)
ms.load_param_into_net(net_g, parameter)

test_data = Tensor(np.random.normal(0, 1, (25, 100)).astype(np.float32))
images = net_g(test_data).transpose(0, 2, 3, 1).asnumpy()

fig = plt.figure(figsize=(3, 3), dpi=120)
for i in range(25):
    fig.add_subplot(5, 5, i + 1)
    plt.axis("off")
    plt.imshow(images[i].squeeze(), cmap="gray")
plt.show()

结果

学习心得:通过这次对GAN图像生成模型的实现,我们发现以下几点非常关键:GAN的对抗训练机制使生成器和判别器在相互博弈中不断提升生成图像的质量;数据预处理的重要性不可忽视,归一化和增强操作显著提高了模型的训练效果;合理的模型结构设计,包括使用适当的网络层数和激活函数,如ReLU和LeakyReLU,有助于提高生成器和判别器的性能;训练过程的监控,通过定期保存模型权重和生成图像,能够直观地评估模型的训练进展,并及时调整训练参数。这些心得有助于我们更好地理解和应用GAN技术,在实际项目中有效提升模型的生成能力。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

  • 8
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

会飞的Anthony

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值