PyTorch Deep Learning Primer (11): GANs (Generative Adversarial Networks)

Overview:

A GAN is an adversarial network made up of two sub-networks: a generator and a discriminator. The generator's task is to produce fake data that resembles the real data, while the discriminator's task is to judge whether its input is real or generated. During training the two networks compete: the generator tries to produce ever more convincing fakes to fool the discriminator, while the discriminator tries to get better at telling real from fake. A GAN training step typically involves:

1. The generator produces fake data and feeds it into the discriminator.
2. The discriminator classifies its input and the classification error is computed.
3. The discriminator's parameters are updated based on that error, improving its accuracy.
4. The generator's parameters are updated based on the discriminator's feedback, so it produces more realistic fakes.

GANs have broad application prospects in image generation, image inpainting, speech recognition, and other areas.

What is DCGAN?
DCGAN is a direct extension of the GAN described above, except that it explicitly uses convolutional and convolutional-transpose layers in the discriminator and generator, respectively. It was first described by Radford et al. in the paper "Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks". The discriminator is made up of strided convolution layers, batch normalization layers, and LeakyReLU activations. Its input is a 3x64x64 image, and its output is a scalar probability that the input came from the real data distribution. The generator is made up of convolutional-transpose layers, batch normalization layers, and ReLU activations. Its input is a latent vector z drawn from a standard normal distribution, and its output is a 3x64x64 RGB image. The strided convolutional-transpose layers allow the latent vector to be transformed into a volume with the same shape as an image. In the paper, the authors also give some tips about how to set up the optimizers, how to compute the loss functions, and how to initialize the model weights, all of which will be explained in the coming sections.

#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True) # Needed for reproducible results

Output:

Random Seed:  999
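One caveat worth knowing (an assumption drawn from PyTorch's reproducibility notes, not shown in the snippet above): when running on CUDA, `torch.use_deterministic_algorithms(True)` additionally requires the `CUBLAS_WORKSPACE_CONFIG` environment variable to be set before any CUDA work happens, or certain cuBLAS operations will raise an error. A minimal sketch:

```python
import os

# Must be set before CUDA is initialized for deterministic cuBLAS behavior.
# ":4096:8" trades a small amount of workspace memory for reproducibility.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
print(os.environ["CUBLAS_WORKSPACE_CONFIG"])
```

On CPU-only runs this variable has no effect, so setting it unconditionally is harmless.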

Inputs

Let's define some inputs for the run:

dataroot - the path to the root of the dataset folder. We will talk more about the dataset in the next section.

workers - the number of worker threads for loading the data with the DataLoader.

batch_size - the batch size used in training. The DCGAN paper uses a batch size of 128.

image_size - the spatial size of the images used for training. This implementation defaults to 64x64. If another size is desired, the structures of D and G must be changed. See here for more details.

nc - number of color channels in the input images. For color images this is 3.

nz - length of the latent vector.

ngf - relates to the depth of feature maps carried through the generator.

ndf - sets the depth of feature maps propagated through the discriminator.

num_epochs - number of training epochs to run. Training for longer will probably lead to better results but will also take longer.

lr - learning rate for training. As described in the DCGAN paper, this number should be 0.0002.

beta1 - beta1 hyperparameter for the Adam optimizers. As described in the paper, this number should be 0.5.

ngpu - number of GPUs available. If this is 0, the code will run in CPU mode. If this number is greater than 0, it will run on that number of GPUs.

# Root directory for dataset
dataroot = "data/celeba"

# Number of workers for dataloader
workers = 2

# Batch size during training
batch_size = 128

# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 64

# Number of channels in the training images. For color images this is 3
nc = 3

# Size of z latent vector (i.e. size of generator input)
nz = 100

# Size of feature maps in generator
ngf = 64

# Size of feature maps in discriminator
ndf = 64

# Number of training epochs
num_epochs = 5

# Learning rate for optimizers
lr = 0.0002

# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5

# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1

Data

In this tutorial we will use the Celeb-A Faces dataset, which can be downloaded at the linked site or from Google Drive. The dataset downloads as a file named img_align_celeba.zip. Once downloaded, create a directory named celeba and extract the zip file into that directory. Then, set the dataroot input for this notebook to the celeba directory you just created. The resulting directory structure should be:

/path/to/celeba
    -> img_align_celeba
        -> 188242.jpg
        -> 173822.jpg
        -> 284702.jpg
        -> 537394.jpg
           ...

This is an important step because we will be using the ImageFolder dataset class, which requires there to be subdirectories in the dataset's root folder. Now, we can create the dataset, create the dataloader, set the device to run on, and finally visualize some of the training data.
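Because ImageFolder treats each immediate subdirectory of the root as a class, a quick sanity check of the layout can save a confusing error later. A small sketch (the `celeba` path is this tutorial's convention, not an API requirement; a temporary directory stands in for the real download):

```python
from pathlib import Path
import tempfile

# Recreate the expected layout in a temporary directory for illustration:
# <root>/img_align_celeba/<image files>
root = Path(tempfile.mkdtemp()) / "celeba"
(root / "img_align_celeba").mkdir(parents=True)

# The root itself must not contain the images directly; ImageFolder
# needs at least one subdirectory to act as the (single) class folder.
subdirs = [p for p in root.iterdir() if p.is_dir()]
print(len(subdirs))
```

If you see `FileNotFoundError: Found no valid file for the classes ...` from ImageFolder, this layout check is the first thing to revisit.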

# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))

(Figure: a grid of training images)
Output:

<matplotlib.image.AxesImage object at 0x7f52efaa4550>

Implementation

With our input parameters set and the dataset prepared, we can now get into the implementation. We will start with the weight initialization strategy, then talk about the generator, discriminator, loss functions, and training loop in detail.

Weight Initialization
In the DCGAN paper, the authors specify that all model weights shall be randomly initialized from a Normal distribution with mean=0, stdev=0.02. The weights_init function takes an initialized model as input and reinitializes all convolutional, convolutional-transpose, and batch normalization layers to meet this criterion. This function is applied to the models immediately after initialization.

# custom weights initialization called on ``netG`` and ``netD``
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
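The dispatch above relies purely on the layer's class name, so it is worth seeing which layers it would actually touch. Below is a pure-Python sketch of the same string matching (the layer names are those used by the models later in this post; `init_branch` is a hypothetical helper for illustration):

```python
def init_branch(classname):
    # Mirrors the dispatch in weights_init: 'Conv' matches both Conv2d
    # and ConvTranspose2d; 'BatchNorm' matches BatchNorm2d.
    if classname.find('Conv') != -1:
        return 'weight ~ normal(0.0, 0.02)'
    elif classname.find('BatchNorm') != -1:
        return 'weight ~ normal(1.0, 0.02), bias = 0'
    return 'untouched'

for name in ['ConvTranspose2d', 'BatchNorm2d', 'ReLU', 'Conv2d', 'Tanh']:
    print(name, '->', init_branch(name))
```

Note that the substring test is deliberately loose: `ConvTranspose2d` contains `Conv`, so one branch covers both the generator's and the discriminator's convolutional layers.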

Generator

The generator, G, is designed to map the latent space vector (z) to data-space. Since our data are images, converting z to data-space means ultimately creating an RGB image with the same size as the training images (i.e. 3x64x64). In practice, this is accomplished through a series of strided two-dimensional convolutional-transpose layers, each paired with a 2d batch norm layer and a ReLU activation. The output of the generator is fed through a tanh function to return it to the input data range of [-1, 1]. It is worth noting the existence of the batch norm functions after the conv-transpose layers, as this is a critical contribution of the DCGAN paper. These layers help with the flow of gradients during training. An image of the generator from the DCGAN paper is shown below.

(Figure: DCGAN generator architecture, from the paper)
Notice how the inputs we set in the input section (nz, ngf, and nc) influence the generator architecture in code. nz is the length of the z input vector, ngf relates to the size of the feature maps that are propagated through the generator, and nc is the number of channels in the output image (set to 3 for RGB images). Below is the code for the generator.

# Generator Code

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. ``(ngf*8) x 4 x 4``
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. ``(ngf*4) x 8 x 8``
            nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. ``(ngf*2) x 16 x 16``
            nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. ``(ngf) x 32 x 32``
            nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. ``(nc) x 64 x 64``
        )

    def forward(self, input):
        return self.main(input)
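The spatial sizes in the comments above (4 → 8 → 16 → 32 → 64) follow from the ConvTranspose2d output-size formula (with output_padding and dilation at their defaults), out = (in − 1)·stride − 2·padding + kernel. A quick arithmetic check of the five layers:

```python
def convtranspose2d_out(size, kernel, stride, padding):
    # Output spatial size of nn.ConvTranspose2d, assuming the defaults
    # output_padding=0 and dilation=1.
    return (size - 1) * stride - 2 * padding + kernel

# (kernel, stride, padding) for each generator layer, as in the code above.
layers = [(4, 1, 0), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1)]
size = 1  # the latent vector z enters as an nz x 1 x 1 volume
sizes = []
for k, s, p in layers:
    size = convtranspose2d_out(size, k, s, p)
    sizes.append(size)
print(sizes)  # [4, 8, 16, 32, 64]
```

This is why the first layer uses stride 1 / padding 0 (to expand 1x1 to 4x4) while every later layer doubles the spatial size.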

Now, we can instantiate the generator and apply the weights_init function. Check out the printed model to see how the generator object is structured.

# Create the generator
netG = Generator(ngpu).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (ngpu > 1):
    netG = nn.DataParallel(netG, list(range(ngpu)))

# Apply the ``weights_init`` function to randomly initialize all weights
#  to ``mean=0``, ``stdev=0.02``.
netG.apply(weights_init)

# Print the model
print(netG)

Output:

Generator(
  (main): Sequential(
    (0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (11): ReLU(inplace=True)
    (12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (13): Tanh()
  )
)

Discriminator
As mentioned, the discriminator, D, is a binary classification network that takes an image as input and outputs a scalar probability that the input image is real (as opposed to fake). Here, D takes a 3x64x64 input image, processes it through a series of Conv2d, BatchNorm2d, and LeakyReLU layers, and outputs the final probability through a Sigmoid activation function. This architecture can be extended with more layers if necessary for the problem, but there is significance to the use of the strided convolutions, BatchNorm, and LeakyReLU. The DCGAN paper mentions it is good practice to use strided convolution rather than pooling to downsample, because it lets the network learn its own pooling function. Also, the batch norm and leaky ReLU functions promote healthy gradient flow, which is critical for the learning process of both G and D.

Discriminator Code

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is ``(nc) x 64 x 64``
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf) x 32 x 32``
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*2) x 16 x 16``
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*4) x 8 x 8``
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. ``(ndf*8) x 4 x 4``
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, input):
        return self.main(input)
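Symmetrically, the discriminator's spatial sizes (64 → 32 → 16 → 8 → 4 → 1) follow from the strided Conv2d formula (with dilation at its default), out = ⌊(in + 2·padding − kernel) / stride⌋ + 1:

```python
def conv2d_out(size, kernel, stride, padding):
    # Output spatial size of nn.Conv2d, assuming the default dilation=1.
    return (size + 2 * padding - kernel) // stride + 1

# (kernel, stride, padding) for each discriminator layer, as in the code above.
layers = [(4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 2, 1), (4, 1, 0)]
size = 64  # input image is nc x 64 x 64
sizes = []
for k, s, p in layers:
    size = conv2d_out(size, k, s, p)
    sizes.append(size)
print(sizes)  # [32, 16, 8, 4, 1]
```

The final 1x1 output is what lets the `.view(-1)` in the training loop flatten D's output into a vector of per-image probabilities.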

Now, as with the generator, we can create the discriminator, apply the weights_init function, and print the model's structure.

# Create the Discriminator
netD = Discriminator(ngpu).to(device)

# Handle multi-GPU if desired
if (device.type == 'cuda') and (ngpu > 1):
    netD = nn.DataParallel(netD, list(range(ngpu)))

# Apply the ``weights_init`` function to randomly initialize all weights
# like this: ``to mean=0, stdev=0.02``.
netD.apply(weights_init)

# Print the model
print(netD)

Output:

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
    (12): Sigmoid()
  )
)

Loss Functions and Optimizers
With D and G set up, we can specify how they learn through the loss functions and optimizers. We will use the Binary Cross Entropy loss (BCELoss) function, which is defined in PyTorch as:

ℓ(x, y) = L = {l_1, ..., l_N}ᵀ,   l_n = -[ y_n · log(x_n) + (1 - y_n) · log(1 - x_n) ]

As specified in the DCGAN paper, both are Adam optimizers with learning rate 0.0002 and Beta1 = 0.5. For keeping track of the generator's learning progression, we will generate a fixed batch of latent vectors drawn from a Gaussian distribution (i.e. fixed_noise). In the training loop, we will periodically input this fixed_noise into G, and over the iterations we will see images form out of the noise.

# Initialize the ``BCELoss`` function
criterion = nn.BCELoss()

# Create batch of latent vectors that we will use to visualize
#  the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.

# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
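To make the label convention concrete, here is the BCE formula from above computed by hand for one prediction on a real image and one on a fake, a pure-Python sketch of what `criterion(output, label)` computes per element:

```python
import math

def bce(prediction, label):
    # Element-wise binary cross-entropy, as in nn.BCELoss:
    # l = -[y*log(x) + (1-y)*log(1-x)]
    return -(label * math.log(prediction)
             + (1 - label) * math.log(1 - prediction))

real_label, fake_label = 1.0, 0.0

# D outputs 0.9 ("probably real") for a real image: small loss...
low = bce(0.9, real_label)
# ...but the same 0.9 on a fake image is punished hard.
high = bce(0.9, fake_label)
print(low, high)
```

The asymmetry is the whole game: the same discriminator output is cheap or expensive depending on which label convention the current mini-batch carries.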

Training

Finally, now that we have all of the parts of the GAN framework defined, we can train it. Be mindful that training GANs is somewhat of an art form in itself, as incorrect hyperparameter settings lead to mode collapse with little explanation of what went wrong. Here, we will closely follow Algorithm 1 from the Goodfellow paper, while abiding by some of the best practices shown in ganhacks. Namely, we will "construct different mini-batches for real and fake" images, and also adjust G's objective function to maximize log(D(G(z))).

# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        ## Train with all-real batch
        netD.zero_grad()
        # Format batch
        real_cpu = data[0].to(device)
        b_size = real_cpu.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
        # Forward pass real batch through D
        output = netD(real_cpu).view(-1)
        # Calculate loss on all-real batch
        errD_real = criterion(output, label)
        # Calculate gradients for D in backward pass
        errD_real.backward()
        D_x = output.mean().item()

        ## Train with all-fake batch
        # Generate batch of latent vectors
        noise = torch.randn(b_size, nz, 1, 1, device=device)
        # Generate fake image batch with G
        fake = netG(noise)
        label.fill_(fake_label)
        # Classify all fake batch with D
        output = netD(fake.detach()).view(-1)
        # Calculate D's loss on the all-fake batch
        errD_fake = criterion(output, label)
        # Calculate the gradients for this batch, accumulated (summed) with previous gradients
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        # Compute error of D as sum over the fake and the real batches
        errD = errD_real + errD_fake
        # Update D
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        # Since we just updated D, perform another forward pass of all-fake batch through D
        output = netD(fake).view(-1)
        # Calculate G's loss based on this output
        errG = criterion(output, label)
        # Calculate gradients for G
        errG.backward()
        D_G_z2 = output.mean().item()
        # Update G
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fake = netG(fixed_noise).detach().cpu()
            img_list.append(vutils.make_grid(fake, padding=2, normalize=True))

        iters += 1

Output:

Starting Training Loop...
[0/5][0/1583]   Loss_D: 1.4640  Loss_G: 6.9360  D(x): 0.7143    D(G(z)): 0.5877 / 0.0017
[0/5][50/1583]  Loss_D: 0.0174  Loss_G: 23.7368 D(x): 0.9881    D(G(z)): 0.0000 / 0.0000
[0/5][100/1583] Loss_D: 0.5983  Loss_G: 9.9471  D(x): 0.9715    D(G(z)): 0.3122 / 0.0003
[0/5][150/1583] Loss_D: 0.4940  Loss_G: 5.6772  D(x): 0.7028    D(G(z)): 0.0241 / 0.0091
[0/5][200/1583] Loss_D: 0.5931  Loss_G: 7.1186  D(x): 0.9423    D(G(z)): 0.3016 / 0.0018
[0/5][250/1583] Loss_D: 0.3846  Loss_G: 3.2697  D(x): 0.7663    D(G(z)): 0.0573 / 0.0739
[0/5][300/1583] Loss_D: 1.3306  Loss_G: 8.3204  D(x): 0.8768    D(G(z)): 0.6353 / 0.0009
[0/5][350/1583] Loss_D: 0.6451  Loss_G: 6.0499  D(x): 0.9025    D(G(z)): 0.3673 / 0.0060
[0/5][400/1583] Loss_D: 0.4211  Loss_G: 3.7316  D(x): 0.8407    D(G(z)): 0.1586 / 0.0392
[0/5][450/1583] Loss_D: 0.6569  Loss_G: 2.4818  D(x): 0.6437    D(G(z)): 0.0858 / 0.1129
[0/5][500/1583] Loss_D: 1.2208  Loss_G: 2.9943  D(x): 0.4179    D(G(z)): 0.0109 / 0.1133
[0/5][550/1583] Loss_D: 0.3400  Loss_G: 4.7669  D(x): 0.9135    D(G(z)): 0.1922 / 0.0145
[0/5][600/1583] Loss_D: 0.5756  Loss_G: 4.8500  D(x): 0.9189    D(G(z)): 0.3193 / 0.0187
[0/5][650/1583] Loss_D: 0.2470  Loss_G: 4.1606  D(x): 0.9460    D(G(z)): 0.1545 / 0.0250
[0/5][700/1583] Loss_D: 0.3887  Loss_G: 4.1884  D(x): 0.8518    D(G(z)): 0.1562 / 0.0297
[0/5][750/1583] Loss_D: 0.5353  Loss_G: 4.1742  D(x): 0.8034    D(G(z)): 0.1958 / 0.0302
[0/5][800/1583] Loss_D: 0.3213  Loss_G: 5.8919  D(x): 0.9076    D(G(z)): 0.1572 / 0.0065
[0/5][850/1583] Loss_D: 0.8850  Loss_G: 7.4333  D(x): 0.9258    D(G(z)): 0.4449 / 0.0017
[0/5][900/1583] Loss_D: 1.2624  Loss_G: 10.0392 D(x): 0.9896    D(G(z)): 0.6361 / 0.0002
[0/5][950/1583] Loss_D: 0.8802  Loss_G: 6.9221  D(x): 0.5527    D(G(z)): 0.0039 / 0.0045
[0/5][1000/1583]        Loss_D: 0.5799  Loss_G: 3.1800  D(x): 0.7062    D(G(z)): 0.0762 / 0.0884
[0/5][1050/1583]        Loss_D: 0.9647  Loss_G: 6.6894  D(x): 0.9429    D(G(z)): 0.5270 / 0.0035
[0/5][1100/1583]        Loss_D: 0.5624  Loss_G: 3.6715  D(x): 0.7944    D(G(z)): 0.2069 / 0.0445
[0/5][1150/1583]        Loss_D: 0.6205  Loss_G: 4.8995  D(x): 0.8634    D(G(z)): 0.3046 / 0.0169
[0/5][1200/1583]        Loss_D: 0.2569  Loss_G: 4.2945  D(x): 0.9455    D(G(z)): 0.1528 / 0.0255
[0/5][1250/1583]        Loss_D: 0.4921  Loss_G: 3.2500  D(x): 0.8152    D(G(z)): 0.1892 / 0.0753
[0/5][1300/1583]        Loss_D: 0.4068  Loss_G: 3.7702  D(x): 0.8153    D(G(z)): 0.1335 / 0.0472
[0/5][1350/1583]        Loss_D: 1.1704  Loss_G: 7.3408  D(x): 0.9443    D(G(z)): 0.5863 / 0.0022
[0/5][1400/1583]        Loss_D: 0.6111  Loss_G: 2.2676  D(x): 0.6714    D(G(z)): 0.0793 / 0.1510
[0/5][1450/1583]        Loss_D: 0.7817  Loss_G: 4.0744  D(x): 0.7915    D(G(z)): 0.3573 / 0.0242
[0/5][1500/1583]        Loss_D: 0.7177  Loss_G: 1.9253  D(x): 0.5770    D(G(z)): 0.0257 / 0.1909
[0/5][1550/1583]        Loss_D: 0.4518  Loss_G: 2.8314  D(x): 0.7991    D(G(z)): 0.1479 / 0.0885
[1/5][0/1583]   Loss_D: 0.4267  Loss_G: 4.5150  D(x): 0.8976    D(G(z)): 0.2401 / 0.0196
[1/5][50/1583]  Loss_D: 0.5106  Loss_G: 2.7800  D(x): 0.7073    D(G(z)): 0.0663 / 0.0932
[1/5][100/1583] Loss_D: 0.6300  Loss_G: 1.8648  D(x): 0.6557    D(G(z)): 0.0756 / 0.2118
[1/5][150/1583] Loss_D: 1.1727  Loss_G: 5.1536  D(x): 0.8397    D(G(z)): 0.5261 / 0.0125
[1/5][200/1583] Loss_D: 0.4675  Loss_G: 2.9615  D(x): 0.7645    D(G(z)): 0.1400 / 0.0780
[1/5][250/1583] Loss_D: 0.7938  Loss_G: 3.1614  D(x): 0.6958    D(G(z)): 0.2248 / 0.0678
[1/5][300/1583] Loss_D: 0.9869  Loss_G: 5.9243  D(x): 0.9619    D(G(z)): 0.5349 / 0.0063
[1/5][350/1583] Loss_D: 0.5178  Loss_G: 3.0236  D(x): 0.7795    D(G(z)): 0.1769 / 0.0700
[1/5][400/1583] Loss_D: 1.4509  Loss_G: 2.7187  D(x): 0.3278    D(G(z)): 0.0133 / 0.1273
[1/5][450/1583] Loss_D: 0.5530  Loss_G: 4.8110  D(x): 0.9151    D(G(z)): 0.3237 / 0.0160
[1/5][500/1583] Loss_D: 0.4621  Loss_G: 4.1158  D(x): 0.8720    D(G(z)): 0.2278 / 0.0293
[1/5][550/1583] Loss_D: 0.4987  Loss_G: 4.0199  D(x): 0.8533    D(G(z)): 0.2367 / 0.0287
[1/5][600/1583] Loss_D: 1.0630  Loss_G: 4.6502  D(x): 0.9145    D(G(z)): 0.5018 / 0.0218
[1/5][650/1583] Loss_D: 0.6081  Loss_G: 4.3172  D(x): 0.8670    D(G(z)): 0.3312 / 0.0221
[1/5][700/1583] Loss_D: 0.4703  Loss_G: 2.4900  D(x): 0.7538    D(G(z)): 0.1245 / 0.1188
[1/5][750/1583] Loss_D: 0.4827  Loss_G: 2.2941  D(x): 0.7372    D(G(z)): 0.1105 / 0.1300
[1/5][800/1583] Loss_D: 0.4013  Loss_G: 3.8850  D(x): 0.8895    D(G(z)): 0.2179 / 0.0324
[1/5][850/1583] Loss_D: 0.7245  Loss_G: 1.9088  D(x): 0.6100    D(G(z)): 0.0950 / 0.1898
[1/5][900/1583] Loss_D: 0.8372  Loss_G: 1.2346  D(x): 0.5232    D(G(z)): 0.0332 / 0.3633
[1/5][950/1583] Loss_D: 0.5561  Loss_G: 3.2048  D(x): 0.7660    D(G(z)): 0.2035 / 0.0594
[1/5][1000/1583]        Loss_D: 0.6859  Loss_G: 1.6347  D(x): 0.5764    D(G(z)): 0.0435 / 0.2540
[1/5][1050/1583]        Loss_D: 0.6785  Loss_G: 4.3244  D(x): 0.9066    D(G(z)): 0.3835 / 0.0203
[1/5][1100/1583]        Loss_D: 0.4835  Loss_G: 2.4080  D(x): 0.7428    D(G(z)): 0.1073 / 0.1147
[1/5][1150/1583]        Loss_D: 0.5507  Loss_G: 2.5400  D(x): 0.7857    D(G(z)): 0.2182 / 0.1092
[1/5][1200/1583]        Loss_D: 0.6054  Loss_G: 3.4802  D(x): 0.8263    D(G(z)): 0.2934 / 0.0441
[1/5][1250/1583]        Loss_D: 0.4788  Loss_G: 2.3533  D(x): 0.7872    D(G(z)): 0.1698 / 0.1327
[1/5][1300/1583]        Loss_D: 0.5314  Loss_G: 2.7018  D(x): 0.8273    D(G(z)): 0.2423 / 0.0921
[1/5][1350/1583]        Loss_D: 0.8579  Loss_G: 4.6214  D(x): 0.9623    D(G(z)): 0.5089 / 0.0159
[1/5][1400/1583]        Loss_D: 0.4919  Loss_G: 2.7656  D(x): 0.8122    D(G(z)): 0.2147 / 0.0864
[1/5][1450/1583]        Loss_D: 0.4461  Loss_G: 3.0576  D(x): 0.8042    D(G(z)): 0.1798 / 0.0619
[1/5][1500/1583]        Loss_D: 0.7182  Loss_G: 3.7270  D(x): 0.8553    D(G(z)): 0.3713 / 0.0382
[1/5][1550/1583]        Loss_D: 0.6378  Loss_G: 3.7489  D(x): 0.8757    D(G(z)): 0.3523 / 0.0317
[2/5][0/1583]   Loss_D: 0.3965  Loss_G: 2.6262  D(x): 0.7941    D(G(z)): 0.1247 / 0.0963
[2/5][50/1583]  Loss_D: 0.6504  Loss_G: 3.9890  D(x): 0.9267    D(G(z)): 0.3865 / 0.0275
[2/5][100/1583] Loss_D: 0.6523  Loss_G: 3.8724  D(x): 0.8707    D(G(z)): 0.3613 / 0.0299
[2/5][150/1583] Loss_D: 0.7685  Loss_G: 3.9059  D(x): 0.9361    D(G(z)): 0.4534 / 0.0278
[2/5][200/1583] Loss_D: 0.6587  Loss_G: 1.9218  D(x): 0.6469    D(G(z)): 0.1291 / 0.1888
[2/5][250/1583] Loss_D: 0.6971  Loss_G: 2.2256  D(x): 0.6208    D(G(z)): 0.1226 / 0.1465
[2/5][300/1583] Loss_D: 0.5797  Loss_G: 2.4846  D(x): 0.7762    D(G(z)): 0.2434 / 0.1098
[2/5][350/1583] Loss_D: 0.4674  Loss_G: 1.8800  D(x): 0.8045    D(G(z)): 0.1903 / 0.1877
[2/5][400/1583] Loss_D: 0.6462  Loss_G: 1.9510  D(x): 0.7018    D(G(z)): 0.1935 / 0.1792
[2/5][450/1583] Loss_D: 0.9817  Loss_G: 4.2519  D(x): 0.9421    D(G(z)): 0.5381 / 0.0233
[2/5][500/1583] Loss_D: 0.7721  Loss_G: 1.0928  D(x): 0.5402    D(G(z)): 0.0316 / 0.3927
[2/5][550/1583] Loss_D: 0.6037  Loss_G: 2.6914  D(x): 0.7719    D(G(z)): 0.2504 / 0.0896
[2/5][600/1583] Loss_D: 1.4213  Loss_G: 5.4727  D(x): 0.9408    D(G(z)): 0.6792 / 0.0064
[2/5][650/1583] Loss_D: 0.7246  Loss_G: 1.7030  D(x): 0.6716    D(G(z)): 0.2184 / 0.2246
[2/5][700/1583] Loss_D: 0.6642  Loss_G: 3.3809  D(x): 0.8554    D(G(z)): 0.3438 / 0.0591
[2/5][750/1583] Loss_D: 0.6649  Loss_G: 2.0197  D(x): 0.7169    D(G(z)): 0.2333 / 0.1565
[2/5][800/1583] Loss_D: 0.4594  Loss_G: 2.6623  D(x): 0.8150    D(G(z)): 0.1930 / 0.0944
[2/5][850/1583] Loss_D: 1.1957  Loss_G: 3.1871  D(x): 0.7790    D(G(z)): 0.5576 / 0.0568
[2/5][900/1583] Loss_D: 0.6657  Loss_G: 1.5311  D(x): 0.7092    D(G(z)): 0.2122 / 0.2558
[2/5][950/1583] Loss_D: 0.6795  Loss_G: 1.4149  D(x): 0.6134    D(G(z)): 0.1195 / 0.2937
[2/5][1000/1583]        Loss_D: 0.5995  Loss_G: 2.1744  D(x): 0.7325    D(G(z)): 0.2054 / 0.1484
[2/5][1050/1583]        Loss_D: 0.6706  Loss_G: 1.6705  D(x): 0.6425    D(G(z)): 0.1414 / 0.2310
[2/5][1100/1583]        Loss_D: 1.2840  Loss_G: 4.4620  D(x): 0.9736    D(G(z)): 0.6601 / 0.0225
[2/5][1150/1583]        Loss_D: 0.7568  Loss_G: 3.1238  D(x): 0.8153    D(G(z)): 0.3717 / 0.0581
[2/5][1200/1583]        Loss_D: 0.6331  Loss_G: 1.9048  D(x): 0.6799    D(G(z)): 0.1604 / 0.1814
[2/5][1250/1583]        Loss_D: 0.5802  Loss_G: 2.4358  D(x): 0.7561    D(G(z)): 0.2194 / 0.1095
[2/5][1300/1583]        Loss_D: 0.9613  Loss_G: 2.3290  D(x): 0.7463    D(G(z)): 0.3952 / 0.1349
[2/5][1350/1583]        Loss_D: 0.5367  Loss_G: 1.7398  D(x): 0.7580    D(G(z)): 0.1898 / 0.2216
[2/5][1400/1583]        Loss_D: 0.7762  Loss_G: 3.6246  D(x): 0.9006    D(G(z)): 0.4378 / 0.0364
[2/5][1450/1583]        Loss_D: 0.7183  Loss_G: 4.0442  D(x): 0.8602    D(G(z)): 0.3857 / 0.0254
[2/5][1500/1583]        Loss_D: 0.5416  Loss_G: 2.0642  D(x): 0.7393    D(G(z)): 0.1758 / 0.1532
[2/5][1550/1583]        Loss_D: 0.5295  Loss_G: 1.7855  D(x): 0.6768    D(G(z)): 0.0886 / 0.2154
[3/5][0/1583]   Loss_D: 0.8635  Loss_G: 1.7508  D(x): 0.4918    D(G(z)): 0.0280 / 0.2154
[3/5][50/1583]  Loss_D: 0.8697  Loss_G: 0.7859  D(x): 0.5216    D(G(z)): 0.1124 / 0.4941
[3/5][100/1583] Loss_D: 0.8607  Loss_G: 4.5255  D(x): 0.9197    D(G(z)): 0.4973 / 0.0157
[3/5][150/1583] Loss_D: 0.4805  Loss_G: 2.3071  D(x): 0.7743    D(G(z)): 0.1742 / 0.1291
[3/5][200/1583] Loss_D: 0.4925  Loss_G: 2.6018  D(x): 0.7907    D(G(z)): 0.1970 / 0.0948
[3/5][250/1583] Loss_D: 0.7870  Loss_G: 3.3529  D(x): 0.8408    D(G(z)): 0.4050 / 0.0469
[3/5][300/1583] Loss_D: 0.5479  Loss_G: 1.7376  D(x): 0.7216    D(G(z)): 0.1592 / 0.2227
[3/5][350/1583] Loss_D: 0.8117  Loss_G: 3.4145  D(x): 0.9076    D(G(z)): 0.4685 / 0.0437
[3/5][400/1583] Loss_D: 0.4210  Loss_G: 2.3880  D(x): 0.7543    D(G(z)): 0.1047 / 0.1217
[3/5][450/1583] Loss_D: 1.5745  Loss_G: 0.2366  D(x): 0.2747    D(G(z)): 0.0361 / 0.8096
[3/5][500/1583] Loss_D: 0.7196  Loss_G: 2.1319  D(x): 0.7332    D(G(z)): 0.2935 / 0.1403
[3/5][550/1583] Loss_D: 0.5697  Loss_G: 2.6649  D(x): 0.8816    D(G(z)): 0.3210 / 0.0917
[3/5][600/1583] Loss_D: 0.7779  Loss_G: 1.2727  D(x): 0.5540    D(G(z)): 0.0855 / 0.3412
[3/5][650/1583] Loss_D: 0.4090  Loss_G: 2.6893  D(x): 0.8334    D(G(z)): 0.1835 / 0.0855
[3/5][700/1583] Loss_D: 0.8108  Loss_G: 3.8991  D(x): 0.9241    D(G(z)): 0.4716 / 0.0281
[3/5][750/1583] Loss_D: 0.9907  Loss_G: 4.7885  D(x): 0.9111    D(G(z)): 0.5402 / 0.0123
[3/5][800/1583] Loss_D: 0.4725  Loss_G: 2.3347  D(x): 0.7577    D(G(z)): 0.1400 / 0.1222
[3/5][850/1583] Loss_D: 1.5580  Loss_G: 4.9586  D(x): 0.8954    D(G(z)): 0.7085 / 0.0132
[3/5][900/1583] Loss_D: 0.5785  Loss_G: 1.6395  D(x): 0.6581    D(G(z)): 0.1003 / 0.2411
[3/5][950/1583] Loss_D: 0.6592  Loss_G: 1.0890  D(x): 0.5893    D(G(z)): 0.0451 / 0.3809
[3/5][1000/1583]        Loss_D: 0.7280  Loss_G: 3.5368  D(x): 0.8898    D(G(z)): 0.4176 / 0.0409
[3/5][1050/1583]        Loss_D: 0.7088  Loss_G: 3.4301  D(x): 0.8558    D(G(z)): 0.3845 / 0.0457
[3/5][1100/1583]        Loss_D: 0.5651  Loss_G: 2.1150  D(x): 0.7602    D(G(z)): 0.2127 / 0.1532
[3/5][1150/1583]        Loss_D: 0.5412  Loss_G: 1.7790  D(x): 0.6602    D(G(z)): 0.0801 / 0.2088
[3/5][1200/1583]        Loss_D: 1.2277  Loss_G: 1.1464  D(x): 0.4864    D(G(z)): 0.2915 / 0.3665
[3/5][1250/1583]        Loss_D: 0.7148  Loss_G: 1.3957  D(x): 0.5948    D(G(z)): 0.1076 / 0.2876
[3/5][1300/1583]        Loss_D: 1.0675  Loss_G: 1.3018  D(x): 0.4056    D(G(z)): 0.0310 / 0.3355
[3/5][1350/1583]        Loss_D: 0.8064  Loss_G: 0.7482  D(x): 0.5846    D(G(z)): 0.1453 / 0.5147
[3/5][1400/1583]        Loss_D: 0.6032  Loss_G: 3.0601  D(x): 0.8474    D(G(z)): 0.3189 / 0.0590
[3/5][1450/1583]        Loss_D: 0.5329  Loss_G: 2.8172  D(x): 0.8234    D(G(z)): 0.2567 / 0.0795
[3/5][1500/1583]        Loss_D: 0.9292  Loss_G: 3.5544  D(x): 0.8686    D(G(z)): 0.4887 / 0.0410
[3/5][1550/1583]        Loss_D: 0.5929  Loss_G: 2.9118  D(x): 0.8614    D(G(z)): 0.3239 / 0.0702
[4/5][0/1583]   Loss_D: 0.5564  Loss_G: 2.7516  D(x): 0.8716    D(G(z)): 0.3145 / 0.0799
[4/5][50/1583]  Loss_D: 1.0485  Loss_G: 0.6751  D(x): 0.4332    D(G(z)): 0.0675 / 0.5568
[4/5][100/1583] Loss_D: 0.6753  Loss_G: 1.4046  D(x): 0.6028    D(G(z)): 0.0882 / 0.2901
[4/5][150/1583] Loss_D: 0.5946  Loss_G: 1.7618  D(x): 0.6862    D(G(z)): 0.1488 / 0.2016
[4/5][200/1583] Loss_D: 0.4866  Loss_G: 2.2638  D(x): 0.7628    D(G(z)): 0.1633 / 0.1321
[4/5][250/1583] Loss_D: 0.7493  Loss_G: 1.0999  D(x): 0.5541    D(G(z)): 0.0659 / 0.3787
[4/5][300/1583] Loss_D: 1.0886  Loss_G: 4.6532  D(x): 0.9370    D(G(z)): 0.5811 / 0.0149
[4/5][350/1583] Loss_D: 0.6106  Loss_G: 1.9212  D(x): 0.6594    D(G(z)): 0.1322 / 0.1825
[4/5][400/1583] Loss_D: 0.5226  Loss_G: 2.9611  D(x): 0.8178    D(G(z)): 0.2378 / 0.0731
[4/5][450/1583] Loss_D: 1.0068  Loss_G: 1.3267  D(x): 0.4310    D(G(z)): 0.0375 / 0.3179
[4/5][500/1583] Loss_D: 3.1088  Loss_G: 0.1269  D(x): 0.0706    D(G(z)): 0.0061 / 0.8897
[4/5][550/1583] Loss_D: 1.7889  Loss_G: 0.4800  D(x): 0.2175    D(G(z)): 0.0143 / 0.6479
[4/5][600/1583] Loss_D: 0.6732  Loss_G: 3.5685  D(x): 0.8775    D(G(z)): 0.3879 / 0.0362
[4/5][650/1583] Loss_D: 0.5169  Loss_G: 2.1943  D(x): 0.7222    D(G(z)): 0.1349 / 0.1416
[4/5][700/1583] Loss_D: 0.4567  Loss_G: 2.4442  D(x): 0.7666    D(G(z)): 0.1410 / 0.1204
[4/5][750/1583] Loss_D: 0.5972  Loss_G: 2.2992  D(x): 0.6286    D(G(z)): 0.0670 / 0.1283
[4/5][800/1583] Loss_D: 0.5461  Loss_G: 1.9777  D(x): 0.7013    D(G(z)): 0.1318 / 0.1795
[4/5][850/1583] Loss_D: 0.6317  Loss_G: 2.2345  D(x): 0.6962    D(G(z)): 0.1854 / 0.1385
[4/5][900/1583] Loss_D: 0.6034  Loss_G: 3.2300  D(x): 0.8781    D(G(z)): 0.3448 / 0.0517
[4/5][950/1583] Loss_D: 0.6371  Loss_G: 2.7755  D(x): 0.8595    D(G(z)): 0.3357 / 0.0826
[4/5][1000/1583]        Loss_D: 0.6077  Loss_G: 3.3958  D(x): 0.9026    D(G(z)): 0.3604 / 0.0458
[4/5][1050/1583]        Loss_D: 0.5057  Loss_G: 3.2545  D(x): 0.8705    D(G(z)): 0.2691 / 0.0546
[4/5][1100/1583]        Loss_D: 0.4552  Loss_G: 2.0632  D(x): 0.7887    D(G(z)): 0.1704 / 0.1524
[4/5][1150/1583]        Loss_D: 0.9933  Loss_G: 1.0264  D(x): 0.4507    D(G(z)): 0.0636 / 0.4182
[4/5][1200/1583]        Loss_D: 0.5037  Loss_G: 1.9940  D(x): 0.6967    D(G(z)): 0.0959 / 0.1698
[4/5][1250/1583]        Loss_D: 0.4760  Loss_G: 2.5973  D(x): 0.8192    D(G(z)): 0.2164 / 0.0945
[4/5][1300/1583]        Loss_D: 1.0137  Loss_G: 3.8782  D(x): 0.9330    D(G(z)): 0.5405 / 0.0309
[4/5][1350/1583]        Loss_D: 0.9084  Loss_G: 3.1406  D(x): 0.7540    D(G(z)): 0.3980 / 0.0648
[4/5][1400/1583]        Loss_D: 0.6724  Loss_G: 4.1269  D(x): 0.9536    D(G(z)): 0.4234 / 0.0236
[4/5][1450/1583]        Loss_D: 0.6452  Loss_G: 3.5163  D(x): 0.8730    D(G(z)): 0.3555 / 0.0412
[4/5][1500/1583]        Loss_D: 0.8843  Loss_G: 1.4950  D(x): 0.5314    D(G(z)): 0.1035 / 0.2835
[4/5][1550/1583]        Loss_D: 2.3345  Loss_G: 1.0675  D(x): 0.1448    D(G(z)): 0.0228 / 0.4177

Results

Finally, let's check out how we did. Here, we will look at three different results. First, we will see how D and G's losses changed during training. Second, we will visualize G's output on the fixed_noise batch for every epoch. And third, we will look at a batch of real data next to a batch of fake data from G.

Loss versus training iteration

Below is a plot of D & G's losses versus training iterations.

plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

(Figure: Generator and Discriminator loss during training)

Visualization of G's progression

Remember how we saved the generator's output on the fixed_noise batch after every epoch of training. Now, we can visualize the training progression of G with an animation. Press the play button to start the animation.

fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())

(Figure: animation of G's output on fixed_noise over the course of training)

Real Images vs. Fake Images

Finally, let's take a look at some real images and fake images side by side.

# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))

# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))

# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()

(Figure: a batch of real images next to a batch of fake images from the last epoch)
