保姆级讲解生成对抗网络GAN,及原始GAN的torch复现

65 篇文章 4 订阅
30 篇文章 4 订阅

原始GAN的torch复现:

# coding:utf-8
# @Email: wangguisen@infinities.com.cn
# @Time: 2022/11/11 10:44 下午
# @File: GAN_demo.py
'''
基于手写数字识别数据的 GAN demo
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms

'''
使用手写数字识别为例
现将数据归一化到(-1,1)
  其和GAN的训练技巧有关,对于生成器最后使用tanh激活,tanh的取值范围就是(-1,1),
  为了方便生成的图片和输入噪声取值范围相同,所以将输入归一化到(-1,1)
'''
# 对数据归一化(-1,1)
transform = transforms.Compose([
    transforms.ToTensor(),            
    transforms.Normalize(0.5, 0.5)    
])

train_ds = torchvision.datasets.MNIST('./data', train=True, transform=transform, download=True)

dataloader = torch.utils.data.DataLoader(train_ds, batch_size=64, shuffle=True)

# imgs, labels = next(iter(dataloader))
# print(imgs.shape)

'''   定义生成器   '''
'''
基于这个例子,输入为长度100的正态分布噪声
输出维度为(1, 28, 28)
'''
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.linears = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28*28),
            nn.Tanh()
        )

    def forward(self, x):
        # x 为长度为100的noise
        out = self.linears(x)
        out = out.view(-1, 28, 28, 1)
        return out

'''   定义判别器   '''
'''
输入维度为(1,,28,,28)
'''
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.linears = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(-1, 28*28)
        return self.linears(x)

'''   初始化模型、优化器、损失   '''
device = 'cuda' if torch.cuda.is_available() else 'cpu'

gen = Generator().to(device)
dis = Discriminator().to(device)

g_optim = torch.optim.Adam(gen.parameters(), lr=0.0001)
d_optim = torch.optim.Adam(dis.parameters(), lr=0.0001)

loss_fn = nn.BCELoss()


'''   绘图函数  '''
def gen_img_plot(net, test_input):
    prediction = np.squeeze(net(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(16):
        plt.subplot(4, 4, i+1)
        plt.imshow((prediction[i] + 1) / 2)
        plt.axis('off')
    plt.show()

test_input = torch.randn(16, 100, device=device)

'''   训练   '''
D_loss = []
G_loss = []
for epoch in range(20):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count = len(dataloader)
    for step, (img, label) in enumerate(dataloader):
        img = img.to(device)
        size = img.size(0)
        random_noise = torch.randn(size, 100, device=device)

        '''判别器优化'''
        d_optim.zero_grad()
        # 判别器输入真实的图片,得到对真实图片的预测结果
        real_output = dis(img)
        # 判别器在真实图片上的损失
        d_real_loss = loss_fn(real_output, torch.ones_like(real_output))    # 希望判别器将真实的数据判别为全1
        d_real_loss.backward()

        # 判别器输入生成的图片,得到判别器在生成图像上的损失
        gen_img = gen(random_noise)
        fake_output = dis(gen_img.detach())   # 对于生成图片产生的损失,我们的优化目标是判别器,希望fake_output被判定为0,来优化判别器,所以要截断梯度,detach会得到一个没有tensor的梯度
        d_fake_loss = loss_fn(fake_output, torch.zeros_like(fake_output))   # 希望判别器将生成的数据判别为全0
        d_fake_loss.backward()

        # 判别器总损失
        d_loss = d_real_loss + d_fake_loss
        d_optim.step()

        '''生成器优化'''
        g_optim.zero_grad()
        fake_output = dis(gen_img)   # 优化生成器,所以不用截断 detach
        # 对于生成器,希望生成的图片判定为1
        g_loss = loss_fn(fake_output, torch.ones_like(fake_output))   # 生成器的损失
        g_loss.backward()
        g_optim.step()

        with torch.no_grad():
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss

    with torch.no_grad():
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        print('Epoch: ', epoch)
        gen_img_plot(net=gen, test_input=test_input)

我们来看一下第一个epoch和第20个epoch的结果:

第一个epoch如下,可见还是很模糊的:

在这里插入图片描述

第20个epoch如下,已经清晰了很多了,并且有了数字的轮廓:

在这里插入图片描述

生成对抗网络(Generative Adversarial Networks, GANs)是由Ian Goodfellow等人在2014年提出的深度学习模型架构[^4]。GAN由两个主要组成部分组成:生成器(Generator)和判别器(Discriminator)。它们通过一种零和博弈的方式相互作用。 **生成器**:尝试学习从随机噪声(通常是高斯分布)中生成与训练数据相似的新样本。它的目标是尽可能地欺骗判别器,使其误认为生成的数据是真实的。 **判别器**:负责区分真实数据和生成的数据。它试图准确地判断输入是来自训练数据还是生成器。 GAN的工作流程如下: 1. **训练过程**:生成器接收随机噪声作为输入并生成假样本,判别器则对这些样本进行分类,判断是真样本还是假样本。生成器根据判别器的反馈更新参数以提高生成能力,判别器也相应地调整其参数以提高识别能力。 2. **对抗迭代**:这两个模型交替优化,直到达到平衡状态,即生成器可以生成足够逼真的样本,使得判别器无法准确区分开来。 **示例代码**(简化版): ```python import torch.nn as nn # 假设我们有简单的生成器和判别器结构 class Generator(nn.Module): def __init__(self): super(Generator, self).__init__() # ... def forward(self, noise): # 生成器的前向传播 pass class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() # ... def forward(self, input): # 判别器的前向传播 pass # 初始化并训练GAN generator = Generator() discriminator = Discriminator() for _ in range(num_epochs): fake_data = generator(noise) real_labels = torch.ones(batch_size) fake_labels = torch.zeros(batch_size) discriminator.zero_grad() d_loss_real = discriminator(real_data).mean() d_loss_fake = discriminator(fake_data.detach()).mean() d_loss = (d_loss_real + d_loss_fake).backward() discriminator_optimizer.step() generator.zero_grad() g_loss = discriminator(generator(noise)).mean().backward() generator_optimizer.step() ```
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

WGS.

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值