GAN(生成对抗网络)生成FashionMNIST图像Pytorch实现(附完整代码!!!)

GAN介绍

背景

近年来,深度学习在很多领域的都取得了突破性进展,但大家似乎发现了这样的一个现实,即深度学习取得突破性进展的工作基本都是判别模型相关的,例如分类模型。2014年Goodfellow 等人启发自博弈论中的二人零和博弈,开创性地提出了生成对抗网络(Generative Adversarial Networks,GAN)。

原理介绍

我们先看名字,生成对抗网络,“生成“的意思就是是一个可以生成数据的模型,也即生成式模型。生成式模型是当前研究的热点,目前大模型大部分都是生成式模型,例如chatgpt可以文本生成文本(文本问答式),文本生成图像等。

那么“对抗”怎么理解呢?

其实“对抗”在GAN中很形象:在GAN中有一个判别器,有一个生成器。生成器的目的是生成假数据,例如生成不存在的图像,欺骗判别器(让判别器无法判别数据到底是真实存在的还是生成器生成的)。判别器的目的是判别真实数据和生成器生成的假数据。判别器和生成器在每一轮训练中同时学习,迭代权重。生成器学习让自己的数据尽量接近真实,而判别器学习让判别尽量准确,这两种模型互相学习,互相博弈,形成了一种对抗式的学习。

GAN就像罪犯和警察一样,罪犯制作假钞,通过不断学习制作假钞的工艺,让自己的假钞尽量逼真,以至于警察无法识别。而警察识别假钞,通过不断学习辨别能力,让识别假钞更准确。那么经过多轮的博弈(学习),罪犯制作假钞的能力提高了,警察识别假钞的能力也提高了。这样我们可以分别使用罪犯模型和警察模型去处理不同的任务。例如,我们可以使用警察模型识别假钞,使用罪犯模型去制造新的假币(日子越来越有判头了)。

GAN结构:

可以看到,两个重要的模块为生成器Generator(G)和判别器Discriminator(D)

对于生成器,生成器的输入为潜向量和随机噪声(其实只输入随机分布也可以),输出为假样本。

对于判别器,判别器的输入包括真实样本和生成器生成的假样本,以一定比例输入判别器中(一般是一半真实样本,一半假样本),输出为0或者1,0表示判别为假样本,1表示判别为真样本。

核心代码解释

损失定义:

# 先看判别器损失
real_label = descriminator(real_image)
d_loss_real = loss_function(real_label, torch.ones_like(real_label))

random_tensor = torch.randn(real_image.size(0), 100).to('cuda')
fake_image = generator(random_tensor)
fake_label = descriminator(fake_image.detach())
d_loss_fake = loss_function(fake_label, torch.zeros_like(fake_label))

d_loss = d_loss_real + d_loss_fake
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()

# 生成器损失
fake_label = descriminator(fake_image)
g_loss = loss_function(fake_label, torch.ones_like(fake_label))

g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()

主要解释一下损失函数的定义,损失包括生成器损失和判别式损失,生成器和判别器各自更新。

对于判别器,他的损失应该包括两部分,第一部分是真实样本标签和1的损失(真实样本应该被判别器判为1),第二部分是假样本标签和0的损失(假样本应该被判别器判为0).

对于生成器,他的损失不好定义,我们借助判别式来定义。生成器生成的样本应该被判别器判为0,但是生成器希望生成的数据尽量真实,也就是希望判别器犯错,把生成的样本判为1,因此损失可以定义为假样本标签和1的损失。

完整代码

生成模型:

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.optim import Adam
import argparse
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt

def download():
    # 将图片转化为张量以及归一化处理
    Trans = torchvision.transforms.Compose(
        [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize(mean=[0.5], std=[0.5])])

    # 下载MNIST对应的训练和测试数据集
    train_data = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=Trans,
    )

    test_data = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=Trans,
    )

    train_Dataloader = DataLoader(train_data,batch_size=64)
    test_Dataloader = DataLoader(test_data,batch_size=999999)

    return train_Dataloader, test_Dataloader, train_data, test_data


class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.judge = nn.Sequential(nn.Linear(28*28,512), nn.ReLU(), nn.Linear(512,256), nn.ReLU(), nn.Linear(256,32), nn.ReLU(), nn.Linear(32,1), nn.Sigmoid())

    def forward(self,image):
        y = self.judge(image)
        return y



class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.generate = nn.Sequential(nn.Linear(100,256), nn.ReLU(), nn.Linear(256,512), nn.ReLU(), nn.Linear(512,28*28))

    def forward(self, x):
        image = self.generate(x)
        return image


def train(descriminator, generator, d_optimizer, g_optimizer, train_dataloader, loss_function):
    for real_image,_ in tqdm(train_dataloader):
        real_image = real_image.to('cuda')
        real_image = real_image.reshape(-1,28*28)

        # 先看判别器损失
        real_label = descriminator(real_image)
        d_loss_real = loss_function(real_label, torch.ones_like(real_label))

        random_tensor = torch.randn(real_image.size(0),100).to('cuda')
        fake_image = generator(random_tensor)
        fake_label = descriminator(fake_image.detach())
        d_loss_fake = loss_function(fake_label, torch.zeros_like(fake_label))

        d_loss = d_loss_real + d_loss_fake
        d_optimizer.zero_grad()
        d_loss.backward()
        d_optimizer.step()


        # 生成器损失
        fake_label = descriminator(fake_image)
        g_loss = loss_function(fake_label, torch.ones_like(fake_label))

        g_optimizer.zero_grad()
        g_loss.backward()
        g_optimizer.step()



if __name__ == "__main__":
    train_dataloader, test_dataloader, train_data, test_data = download()

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    descriminator = Discriminator().to(device)
    generator = Generator().to(device)

    loss_function = nn.BCELoss()

    d_optimizer = Adam(descriminator.parameters(), lr=0.001)
    g_optimizer = Adam(generator.parameters(), lr=0.001)

    epochs = 30
    for epoch in range(epochs):
        print("training epoch:",epoch)
        train(descriminator, generator, d_optimizer, g_optimizer, train_dataloader, loss_function)


    torch.save(generator.state_dict(),'./generator.pth')
    torch.save(descriminator.state_dict(),'./descriminator.pth')
    print("模型保存成功")

    new_generator = Generator()
    new_generator.load_state_dict(torch.load('./generator.pth'))
    print("生成模型重载成功")

使用生成器生成样本:

import torch
from  VAE import VAE
import  torchvision
import numpy as np
import matplotlib.pyplot as plt
from GAN import Generator

new_generator = Generator()
new_generator.load_state_dict(torch.load('./generator.pth'))
print("生成模型重载成功")

with torch.no_grad():
    x = torch.randn(16, 100)
    fake = new_generator(x)

    fake = fake.reshape(-1, 1, 28, 28)
    img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)

    # 将图像张量转换为 NumPy 数组
    img_grid_fake_np = img_grid_fake.cpu().numpy()
    img_grid_fake_np = np.transpose(img_grid_fake_np, (1, 2, 0))

    # 使用 matplotlib 显示图像
    plt.figure(figsize=(10, 10))
    plt.imshow(img_grid_fake_np)
    plt.axis('off')  # 不显示坐标轴
    plt.show()

运行结果

是不是很神奇,生成了和数据集类似的图像,这些图像并不真实存在。

需要注意的是,图片比较模糊,表示生成的效果欠佳。这是因为只训练了30轮,一步到位也没有进行调参。读者可以增加训练轮数以及通过网格搜索的形式改变超参数,会得到效果更好的图像。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值