概述:
GAN网络是一种对抗神经网络,由生成器和判别器两个网络组成。生成器的任务是生成与真实数据相似的虚假数据,而判别器的任务是判断输入数据是真实数据还是虚假数据。在训练过程中,生成器和判别器会相互对抗,生成器会努力生成与真实数据相似的虚假数据来欺骗判别器,而判别器则会尽可能提高自己的判断准确性。GAN网络的训练过程通常包括以下几个步骤:生成器生成虚假数据,并将其输入到判别器中进行判断。判别器对输入数据进行分类,并计算分类误差。根据分类误差,更新判别器的参数,以提高其分类准确率。生成器根据判别器的反馈信息,更新自身的参数,以生成更加逼真的虚假数据。GAN网络在图像生成、图像修复、语音识别等领域具有广泛的应用前景。
什么是 DCGAN?
DCGAN 是上述 GAN 的直接扩展,不同之处在于它分别在判别器和生成器中明确使用卷积层和卷积转置层。它首先由 Radford 等人描述。等人。在论文《使用深度卷积生成对抗网络进行无监督表示学习》中。判别器由跨步 卷积 层、批归一化 层和 LeakyReLU 激活组成。输入是 3x64x64 输入图像,输出是输入来自真实数据分布的标量概率。生成器由 卷积转置 层、批归一化层和 ReLU激活层组成。输入是一个潜在向量z,它是根据标准正态分布绘制的,输出是 3x64x64 RGB 图像。跨步的转换转置层允许将潜在向量转换为与图像形状相同的体积。在论文中,作者还给出了一些关于如何设置优化器、如何计算损失函数以及如何初始化模型权重的提示,所有这些都将在接下来的章节中进行解释。
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.use_deterministic_algorithms(True) # Needed for reproducible results
输出:
Random Seed: 999
输入
让我们为运行定义一些输入:
dataroot- 数据集文件夹根目录的路径。我们将在下一节中详细讨论数据集。
workers- 用于加载数据的工作线程数DataLoader。
batch_size- 训练中使用的批量大小。DCGAN 论文使用的批量大小为 128。
image_size- 用于训练的图像的空间大小。此实现默认为 64x64。如果需要其他尺寸,则必须改变D和G的结构。请参阅 此处了解更多详细信息。
nc- 输入图像中的颜色通道数。对于彩色图像,该值为 3。
nz- 潜在向量的长度。
ngf- 与生成器所承载的特征图的深度相关。
ndf- 设置通过鉴别器传播的特征图的深度。
num_epochs- 要运行的训练周期数。更长时间的训练可能会带来更好的结果,但也会花费更长的时间。
lr- 训练的学习率。正如 DCGAN 论文中所述,这个数字应该是 0.0002。
beta1- Adam 优化器的 beta1 超参数。正如论文中所述,这个数字应该是 0.5。
ngpu- 可用 GPU 的数量。如果为 0,代码将在 CPU 模式下运行。如果该数字大于 0,它将在该数量的 GPU 上运行。
# Root directory for dataset
dataroot = "data/celeba"
# Number of workers for dataloader
workers = 2
# Batch size during training
batch_size = 128
# Spatial size of training images. All images will be resized to this
# size using a transformer.
image_size = 64
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 5
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparameter for Adam optimizers
beta1 = 0.5
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 1
数据
在本教程中,我们将使用Celeb-A Faces 数据集,该数据集可以在链接站点或Google Drive中下载。数据集将下载为名为img_align_celeba.zip. 下载后,创建一个名为 的目录celeba并将 zip 文件解压到该目录中。然后,dataroot将此笔记本的输入设置为celeba您刚刚创建的目录。最终的目录结构应该是:
/path/to/celeba
-> img_align_celeba
-> 188242.jpg
-> 173822.jpg
-> 284702.jpg
-> 537394.jpg
...
这是重要的一步,因为我们将使用ImageFolder 数据集类,这要求数据集根文件夹中有子目录。现在,我们可以创建数据集,创建数据加载器,设置要运行的设备,最后可视化一些训练数据。
# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
transform=transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
shuffle=True, num_workers=workers)
# Decide which device we want to run on
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
# Plot some training images
real_batch = next(iter(dataloader))
plt.figure(figsize=(8,8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
输出:
<matplotlib.image.AxesImage object at 0x7f52efaa4550>
执行
设置好输入参数并准备好数据集后,我们现在可以开始实施了。我们将从权重初始化策略开始,然后详细讨论生成器、判别器、损失函数和训练循环。
权重初始化
在 DCGAN 论文中,作者指定所有模型权重均应从正态分布随机初始化mean=0, , stdev=0.02。该weights_init函数采用初始化的模型作为输入,并重新初始化所有卷积、卷积转置和批量归一化层以满足此标准。该函数在初始化后立即应用于模型。
# custom weights initialization called on ``netG`` and ``netD``
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
生成器
生成器G,旨在映射潜在空间向量(z) 到数据空间。由于我们的数据是图像,因此转换 z到数据空间意味着最终创建一个与训练图像大小相同的 RGB 图像(即 3x64x64)。在实践中,这是通过一系列跨步二维卷积转置层来完成的,每个层都与一个 2d 批归一化层和一个 relu 激活配对。生成器的输出通过 tanh 函数返回到输入数据范围[-1,1][ - 1 ,1 ]。值得注意的是,在卷积转置层之后存在批归一化函数,因为这是 DCGAN 论文的关键贡献。这些层有助于训练期间梯度的流动。DCGAN 论文中的生成器图像如下所示。
dcgan_生成器
nz请注意,我们在输入部分( 、ngf和 )中设置的输入如何nc影响代码中的生成器架构。nz是 z 输入向量的长度,ngf与通过生成器传播的特征图的大小相关,并且nc是输出图像中的通道数(对于 RGB 图像设置为 3)。下面是生成器的代码。
# Generator Code
class Generator(nn.Module):
def __init__(self, ngpu):
super(Generator, self).__init__()
self.ngpu = ngpu
self.main = nn.Sequential(
# input is Z, going into a convolution
nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
# state size. ``(ngf*8) x 4 x 4``
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# state size. ``(ngf*4) x 8 x 8``
nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# state size. ``(ngf*2) x 16 x 16``
nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# state size. ``(ngf) x 32 x 32``
nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
nn.Tanh()
# state size. ``(nc) x 64 x 64``
)
def forward(self, input):
return self.main(input)
现在,我们可以实例化生成器并应用该weights_init 函数。查看打印的模型以了解生成器对象的结构。
# Create the generator
netG = Generator(ngpu).to(device)
# Handle multi-GPU if desired
if (device.type == 'cuda') and (ngpu > 1):
netG = nn.DataParallel(netG, list(range(ngpu)))
# Apply the ``weights_init`` function to randomly initialize all weights
# to ``mean=0``, ``stdev=0.02``.
netG.apply(weights_init)
# Print the model
print(netG)
输出:
Generator(
(main): Sequential(
(0): ConvTranspose2d(100, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(2): ReLU(inplace=True)
(3): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(5): ReLU(inplace=True)
(6): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(8): ReLU(inplace=True)
(9): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(10): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(11): ReLU(inplace=True)
(12): ConvTranspose2d(64, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(13): Tanh()
)
)
鉴别器
如前所述,判别器D,是一个二元分类网络,它将图像作为输入并输出输入图像是真实的(而不是假的)的标量概率。这里D获取 3x64x64 输入图像,通过一系列 Conv2d、BatchNorm2d 和 LeakyReLU 层对其进行处理,并通过 Sigmoid 激活函数输出最终概率。如果问题需要,可以使用更多层来扩展该架构,但使用跨步卷积、BatchNorm 和 LeakyReLU 具有重要意义。DCGAN 论文提到,使用跨步卷积而不是池化来下采样是一种很好的做法,因为它可以让网络学习自己的池化函数。批归一化和泄漏 ReLU 函数还可以促进健康的梯度流,这对于两者的学习过程都至关重要 G和D。
鉴别器代码
class Discriminator(nn.Module):
def __init__(self, ngpu):
super(Discriminator, self).__init__()
self.ngpu = ngpu
self.main = nn.Sequential(
# input is ``(nc) x 64 x 64``
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. ``(ndf) x 32 x 32``
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# state size. ``(ndf*2) x 16 x 16``
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# state size. ``(ndf*4) x 8 x 8``
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# state size. ``(ndf*8) x 4 x 4``
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, input):
return self.main(input)
现在,与生成器一样,我们可以创建鉴别器、应用函数 weights_init并打印模型的结构。
# Create the Discriminator
netD = Discriminator(ngpu).to(device)
# Handle multi-GPU if desired
if (device.type == 'cuda') and (ngpu > 1):
netD = nn.DataParallel(netD, list(range(ngpu)))
# Apply the ``weights_init`` function to randomly initialize all weights
# like this: ``to mean=0, stdev=0.2``.
netD.apply(weights_init)
# Print the model
print(netD)
输出:
Discriminator(
(main): Sequential(
(0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(1): LeakyReLU(negative_slope=0.2, inplace=True)
(2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(4): LeakyReLU(negative_slope=0.2, inplace=True)
(5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(6): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(7): LeakyReLU(negative_slope=0.2, inplace=True)
(8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
(9): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(10): LeakyReLU(negative_slope=0.2, inplace=True)
(11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
(12): Sigmoid()
)
)
损失函数和优化器
和D和G设置中,我们可以指定它们如何通过损失函数和优化器学习。我们将使用二元交叉熵损失 ( BCELoss ) 函数,该函数在 PyTorch 中定义为:
正如 DCGAN 论文中所指定的,两者都是 Adam 优化器,学习率为 0.0002,Beta1 = 0.5。为了跟踪生成器的学习进度,我们将生成一批固定的潜在向量,这些向量是从高斯分布(即固定噪声)中提取的。在训练循环中,我们会定期将这个fixed_noise输入到G,并且在迭代过程中,我们将看到图像是从噪声中形成的。
# Initialize the ``BCELoss`` function
criterion = nn.BCELoss()
# Create batch of latent vectors that we will use to visualize
# the progression of the generator
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
# Establish convention for real and fake labels during training
real_label = 1.
fake_label = 0.
# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
训练
最后,现在我们已经定义了 GAN 框架的所有部分,我们可以训练它了。请注意,训练 GAN 在某种程度上是一种艺术形式,因为不正确的超参数设置会导致模式崩溃,而对问题所在却几乎没有解释。在这里,我们将严格遵循Goodfellow 论文中的算法 1,同时遵守ganhacks中显示的一些最佳实践。也就是说,我们将“为真假图像构建不同的小批量”,并调整 G 的目标函数以最大化
# Training Loop
# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
# For each batch in the dataloader
for i, data in enumerate(dataloader, 0):
############################
# (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
###########################
## Train with all-real batch
netD.zero_grad()
# Format batch
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
# Forward pass real batch through D
output = netD(real_cpu).view(-1)
# Calculate loss on all-real batch
errD_real = criterion(output, label)
# Calculate gradients for D in backward pass
errD_real.backward()
D_x = output.mean().item()
## Train with all-fake batch
# Generate batch of latent vectors
noise = torch.randn(b_size, nz, 1, 1, device=device)
# Generate fake image batch with G
fake = netG(noise)
label.fill_(fake_label)
# Classify all fake batch with D
output = netD(fake.detach()).view(-1)
# Calculate D's loss on the all-fake batch
errD_fake = criterion(output, label)
# Calculate the gradients for this batch, accumulated (summed) with previous gradients
errD_fake.backward()
D_G_z1 = output.mean().item()
# Compute error of D as sum over the fake and the real batches
errD = errD_real + errD_fake
# Update D
optimizerD.step()
############################
# (2) Update G network: maximize log(D(G(z)))
###########################
netG.zero_grad()
label.fill_(real_label) # fake labels are real for generator cost
# Since we just updated D, perform another forward pass of all-fake batch through D
output = netD(fake).view(-1)
# Calculate G's loss based on this output
errG = criterion(output, label)
# Calculate gradients for G
errG.backward()
D_G_z2 = output.mean().item()
# Update G
optimizerG.step()
# Output training stats
if i % 50 == 0:
print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
% (epoch, num_epochs, i, len(dataloader),
errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
# Save Losses for plotting later
G_losses.append(errG.item())
D_losses.append(errD.item())
# Check how the generator is doing by saving G's output on fixed_noise
if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
with torch.no_grad():
fake = netG(fixed_noise).detach().cpu()
img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
iters += 1
输出:
Starting Training Loop...
[0/5][0/1583] Loss_D: 1.4640 Loss_G: 6.9360 D(x): 0.7143 D(G(z)): 0.5877 / 0.0017
[0/5][50/1583] Loss_D: 0.0174 Loss_G: 23.7368 D(x): 0.9881 D(G(z)): 0.0000 / 0.0000
[0/5][100/1583] Loss_D: 0.5983 Loss_G: 9.9471 D(x): 0.9715 D(G(z)): 0.3122 / 0.0003
[0/5][150/1583] Loss_D: 0.4940 Loss_G: 5.6772 D(x): 0.7028 D(G(z)): 0.0241 / 0.0091
[0/5][200/1583] Loss_D: 0.5931 Loss_G: 7.1186 D(x): 0.9423 D(G(z)): 0.3016 / 0.0018
[0/5][250/1583] Loss_D: 0.3846 Loss_G: 3.2697 D(x): 0.7663 D(G(z)): 0.0573 / 0.0739
[0/5][300/1583] Loss_D: 1.3306 Loss_G: 8.3204 D(x): 0.8768 D(G(z)): 0.6353 / 0.0009
[0/5][350/1583] Loss_D: 0.6451 Loss_G: 6.0499 D(x): 0.9025 D(G(z)): 0.3673 / 0.0060
[0/5][400/1583] Loss_D: 0.4211 Loss_G: 3.7316 D(x): 0.8407 D(G(z)): 0.1586 / 0.0392
[0/5][450/1583] Loss_D: 0.6569 Loss_G: 2.4818 D(x): 0.6437 D(G(z)): 0.0858 / 0.1129
[0/5][500/1583] Loss_D: 1.2208 Loss_G: 2.9943 D(x): 0.4179 D(G(z)): 0.0109 / 0.1133
[0/5][550/1583] Loss_D: 0.3400 Loss_G: 4.7669 D(x): 0.9135 D(G(z)): 0.1922 / 0.0145
[0/5][600/1583] Loss_D: 0.5756 Loss_G: 4.8500 D(x): 0.9189 D(G(z)): 0.3193 / 0.0187
[0/5][650/1583] Loss_D: 0.2470 Loss_G: 4.1606 D(x): 0.9460 D(G(z)): 0.1545 / 0.0250
[0/5][700/1583] Loss_D: 0.3887 Loss_G: 4.1884 D(x): 0.8518 D(G(z)): 0.1562 / 0.0297
[0/5][750/1583] Loss_D: 0.5353 Loss_G: 4.1742 D(x): 0.8034 D(G(z)): 0.1958 / 0.0302
[0/5][800/1583] Loss_D: 0.3213 Loss_G: 5.8919 D(x): 0.9076 D(G(z)): 0.1572 / 0.0065
[0/5][850/1583] Loss_D: 0.8850 Loss_G: 7.4333 D(x): 0.9258 D(G(z)): 0.4449 / 0.0017
[0/5][900/1583] Loss_D: 1.2624 Loss_G: 10.0392 D(x): 0.9896 D(G(z)): 0.6361 / 0.0002
[0/5][950/1583] Loss_D: 0.8802 Loss_G: 6.9221 D(x): 0.5527 D(G(z)): 0.0039 / 0.0045
[0/5][1000/1583] Loss_D: 0.5799 Loss_G: 3.1800 D(x): 0.7062 D(G(z)): 0.0762 / 0.0884
[0/5][1050/1583] Loss_D: 0.9647 Loss_G: 6.6894 D(x): 0.9429 D(G(z)): 0.5270 / 0.0035
[0/5][1100/1583] Loss_D: 0.5624 Loss_G: 3.6715 D(x): 0.7944 D(G(z)): 0.2069 / 0.0445
[0/5][1150/1583] Loss_D: 0.6205 Loss_G: 4.8995 D(x): 0.8634 D(G(z)): 0.3046 / 0.0169
[0/5][1200/1583] Loss_D: 0.2569 Loss_G: 4.2945 D(x): 0.9455 D(G(z)): 0.1528 / 0.0255
[0/5][1250/1583] Loss_D: 0.4921 Loss_G: 3.2500 D(x): 0.8152 D(G(z)): 0.1892 / 0.0753
[0/5][1300/1583] Loss_D: 0.4068 Loss_G: 3.7702 D(x): 0.8153 D(G(z)): 0.1335 / 0.0472
[0/5][1350/1583] Loss_D: 1.1704 Loss_G: 7.3408 D(x): 0.9443 D(G(z)): 0.5863 / 0.0022
[0/5][1400/1583] Loss_D: 0.6111 Loss_G: 2.2676 D(x): 0.6714 D(G(z)): 0.0793 / 0.1510
[0/5][1450/1583] Loss_D: 0.7817 Loss_G: 4.0744 D(x): 0.7915 D(G(z)): 0.3573 / 0.0242
[0/5][1500/1583] Loss_D: 0.7177 Loss_G: 1.9253 D(x): 0.5770 D(G(z)): 0.0257 / 0.1909
[0/5][1550/1583] Loss_D: 0.4518 Loss_G: 2.8314 D(x): 0.7991 D(G(z)): 0.1479 / 0.0885
[1/5][0/1583] Loss_D: 0.4267 Loss_G: 4.5150 D(x): 0.8976 D(G(z)): 0.2401 / 0.0196
[1/5][50/1583] Loss_D: 0.5106 Loss_G: 2.7800 D(x): 0.7073 D(G(z)): 0.0663 / 0.0932
[1/5][100/1583] Loss_D: 0.6300 Loss_G: 1.8648 D(x): 0.6557 D(G(z)): 0.0756 / 0.2118
[1/5][150/1583] Loss_D: 1.1727 Loss_G: 5.1536 D(x): 0.8397 D(G(z)): 0.5261 / 0.0125
[1/5][200/1583] Loss_D: 0.4675 Loss_G: 2.9615 D(x): 0.7645 D(G(z)): 0.1400 / 0.0780
[1/5][250/1583] Loss_D: 0.7938 Loss_G: 3.1614 D(x): 0.6958 D(G(z)): 0.2248 / 0.0678
[1/5][300/1583] Loss_D: 0.9869 Loss_G: 5.9243 D(x): 0.9619 D(G(z)): 0.5349 / 0.0063
[1/5][350/1583] Loss_D: 0.5178 Loss_G: 3.0236 D(x): 0.7795 D(G(z)): 0.1769 / 0.0700
[1/5][400/1583] Loss_D: 1.4509 Loss_G: 2.7187 D(x): 0.3278 D(G(z)): 0.0133 / 0.1273
[1/5][450/1583] Loss_D: 0.5530 Loss_G: 4.8110 D(x): 0.9151 D(G(z)): 0.3237 / 0.0160
[1/5][500/1583] Loss_D: 0.4621 Loss_G: 4.1158 D(x): 0.8720 D(G(z)): 0.2278 / 0.0293
[1/5][550/1583] Loss_D: 0.4987 Loss_G: 4.0199 D(x): 0.8533 D(G(z)): 0.2367 / 0.0287
[1/5][600/1583] Loss_D: 1.0630 Loss_G: 4.6502 D(x): 0.9145 D(G(z)): 0.5018 / 0.0218
[1/5][650/1583] Loss_D: 0.6081 Loss_G: 4.3172 D(x): 0.8670 D(G(z)): 0.3312 / 0.0221
[1/5][700/1583] Loss_D: 0.4703 Loss_G: 2.4900 D(x): 0.7538 D(G(z)): 0.1245 / 0.1188
[1/5][750/1583] Loss_D: 0.4827 Loss_G: 2.2941 D(x): 0.7372 D(G(z)): 0.1105 / 0.1300
[1/5][800/1583] Loss_D: 0.4013 Loss_G: 3.8850 D(x): 0.8895 D(G(z)): 0.2179 / 0.0324
[1/5][850/1583] Loss_D: 0.7245 Loss_G: 1.9088 D(x): 0.6100 D(G(z)): 0.0950 / 0.1898
[1/5][900/1583] Loss_D: 0.8372 Loss_G: 1.2346 D(x): 0.5232 D(G(z)): 0.0332 / 0.3633
[1/5][950/1583] Loss_D: 0.5561 Loss_G: 3.2048 D(x): 0.7660 D(G(z)): 0.2035 / 0.0594
[1/5][1000/1583] Loss_D: 0.6859 Loss_G: 1.6347 D(x): 0.5764 D(G(z)): 0.0435 / 0.2540
[1/5][1050/1583] Loss_D: 0.6785 Loss_G: 4.3244 D(x): 0.9066 D(G(z)): 0.3835 / 0.0203
[1/5][1100/1583] Loss_D: 0.4835 Loss_G: 2.4080 D(x): 0.7428 D(G(z)): 0.1073 / 0.1147
[1/5][1150/1583] Loss_D: 0.5507 Loss_G: 2.5400 D(x): 0.7857 D(G(z)): 0.2182 / 0.1092
[1/5][1200/1583] Loss_D: 0.6054 Loss_G: 3.4802 D(x): 0.8263 D(G(z)): 0.2934 / 0.0441
[1/5][1250/1583] Loss_D: 0.4788 Loss_G: 2.3533 D(x): 0.7872 D(G(z)): 0.1698 / 0.1327
[1/5][1300/1583] Loss_D: 0.5314 Loss_G: 2.7018 D(x): 0.8273 D(G(z)): 0.2423 / 0.0921
[1/5][1350/1583] Loss_D: 0.8579 Loss_G: 4.6214 D(x): 0.9623 D(G(z)): 0.5089 / 0.0159
[1/5][1400/1583] Loss_D: 0.4919 Loss_G: 2.7656 D(x): 0.8122 D(G(z)): 0.2147 / 0.0864
[1/5][1450/1583] Loss_D: 0.4461 Loss_G: 3.0576 D(x): 0.8042 D(G(z)): 0.1798 / 0.0619
[1/5][1500/1583] Loss_D: 0.7182 Loss_G: 3.7270 D(x): 0.8553 D(G(z)): 0.3713 / 0.0382
[1/5][1550/1583] Loss_D: 0.6378 Loss_G: 3.7489 D(x): 0.8757 D(G(z)): 0.3523 / 0.0317
[2/5][0/1583] Loss_D: 0.3965 Loss_G: 2.6262 D(x): 0.7941 D(G(z)): 0.1247 / 0.0963
[2/5][50/1583] Loss_D: 0.6504 Loss_G: 3.9890 D(x): 0.9267 D(G(z)): 0.3865 / 0.0275
[2/5][100/1583] Loss_D: 0.6523 Loss_G: 3.8724 D(x): 0.8707 D(G(z)): 0.3613 / 0.0299
[2/5][150/1583] Loss_D: 0.7685 Loss_G: 3.9059 D(x): 0.9361 D(G(z)): 0.4534 / 0.0278
[2/5][200/1583] Loss_D: 0.6587 Loss_G: 1.9218 D(x): 0.6469 D(G(z)): 0.1291 / 0.1888
[2/5][250/1583] Loss_D: 0.6971 Loss_G: 2.2256 D(x): 0.6208 D(G(z)): 0.1226 / 0.1465
[2/5][300/1583] Loss_D: 0.5797 Loss_G: 2.4846 D(x): 0.7762 D(G(z)): 0.2434 / 0.1098
[2/5][350/1583] Loss_D: 0.4674 Loss_G: 1.8800 D(x): 0.8045 D(G(z)): 0.1903 / 0.1877
[2/5][400/1583] Loss_D: 0.6462 Loss_G: 1.9510 D(x): 0.7018 D(G(z)): 0.1935 / 0.1792
[2/5][450/1583] Loss_D: 0.9817 Loss_G: 4.2519 D(x): 0.9421 D(G(z)): 0.5381 / 0.0233
[2/5][500/1583] Loss_D: 0.7721 Loss_G: 1.0928 D(x): 0.5402 D(G(z)): 0.0316 / 0.3927
[2/5][550/1583] Loss_D: 0.6037 Loss_G: 2.6914 D(x): 0.7719 D(G(z)): 0.2504 / 0.0896
[2/5][600/1583] Loss_D: 1.4213 Loss_G: 5.4727 D(x): 0.9408 D(G(z)): 0.6792 / 0.0064
[2/5][650/1583] Loss_D: 0.7246 Loss_G: 1.7030 D(x): 0.6716 D(G(z)): 0.2184 / 0.2246
[2/5][700/1583] Loss_D: 0.6642 Loss_G: 3.3809 D(x): 0.8554 D(G(z)): 0.3438 / 0.0591
[2/5][750/1583] Loss_D: 0.6649 Loss_G: 2.0197 D(x): 0.7169 D(G(z)): 0.2333 / 0.1565
[2/5][800/1583] Loss_D: 0.4594 Loss_G: 2.6623 D(x): 0.8150 D(G(z)): 0.1930 / 0.0944
[2/5][850/1583] Loss_D: 1.1957 Loss_G: 3.1871 D(x): 0.7790 D(G(z)): 0.5576 / 0.0568
[2/5][900/1583] Loss_D: 0.6657 Loss_G: 1.5311 D(x): 0.7092 D(G(z)): 0.2122 / 0.2558
[2/5][950/1583] Loss_D: 0.6795 Loss_G: 1.4149 D(x): 0.6134 D(G(z)): 0.1195 / 0.2937
[2/5][1000/1583] Loss_D: 0.5995 Loss_G: 2.1744 D(x): 0.7325 D(G(z)): 0.2054 / 0.1484
[2/5][1050/1583] Loss_D: 0.6706 Loss_G: 1.6705 D(x): 0.6425 D(G(z)): 0.1414 / 0.2310
[2/5][1100/1583] Loss_D: 1.2840 Loss_G: 4.4620 D(x): 0.9736 D(G(z)): 0.6601 / 0.0225
[2/5][1150/1583] Loss_D: 0.7568 Loss_G: 3.1238 D(x): 0.8153 D(G(z)): 0.3717 / 0.0581
[2/5][1200/1583] Loss_D: 0.6331 Loss_G: 1.9048 D(x): 0.6799 D(G(z)): 0.1604 / 0.1814
[2/5][1250/1583] Loss_D: 0.5802 Loss_G: 2.4358 D(x): 0.7561 D(G(z)): 0.2194 / 0.1095
[2/5][1300/1583] Loss_D: 0.9613 Loss_G: 2.3290 D(x): 0.7463 D(G(z)): 0.3952 / 0.1349
[2/5][1350/1583] Loss_D: 0.5367 Loss_G: 1.7398 D(x): 0.7580 D(G(z)): 0.1898 / 0.2216
[2/5][1400/1583] Loss_D: 0.7762 Loss_G: 3.6246 D(x): 0.9006 D(G(z)): 0.4378 / 0.0364
[2/5][1450/1583] Loss_D: 0.7183 Loss_G: 4.0442 D(x): 0.8602 D(G(z)): 0.3857 / 0.0254
[2/5][1500/1583] Loss_D: 0.5416 Loss_G: 2.0642 D(x): 0.7393 D(G(z)): 0.1758 / 0.1532
[2/5][1550/1583] Loss_D: 0.5295 Loss_G: 1.7855 D(x): 0.6768 D(G(z)): 0.0886 / 0.2154
[3/5][0/1583] Loss_D: 0.8635 Loss_G: 1.7508 D(x): 0.4918 D(G(z)): 0.0280 / 0.2154
[3/5][50/1583] Loss_D: 0.8697 Loss_G: 0.7859 D(x): 0.5216 D(G(z)): 0.1124 / 0.4941
[3/5][100/1583] Loss_D: 0.8607 Loss_G: 4.5255 D(x): 0.9197 D(G(z)): 0.4973 / 0.0157
[3/5][150/1583] Loss_D: 0.4805 Loss_G: 2.3071 D(x): 0.7743 D(G(z)): 0.1742 / 0.1291
[3/5][200/1583] Loss_D: 0.4925 Loss_G: 2.6018 D(x): 0.7907 D(G(z)): 0.1970 / 0.0948
[3/5][250/1583] Loss_D: 0.7870 Loss_G: 3.3529 D(x): 0.8408 D(G(z)): 0.4050 / 0.0469
[3/5][300/1583] Loss_D: 0.5479 Loss_G: 1.7376 D(x): 0.7216 D(G(z)): 0.1592 / 0.2227
[3/5][350/1583] Loss_D: 0.8117 Loss_G: 3.4145 D(x): 0.9076 D(G(z)): 0.4685 / 0.0437
[3/5][400/1583] Loss_D: 0.4210 Loss_G: 2.3880 D(x): 0.7543 D(G(z)): 0.1047 / 0.1217
[3/5][450/1583] Loss_D: 1.5745 Loss_G: 0.2366 D(x): 0.2747 D(G(z)): 0.0361 / 0.8096
[3/5][500/1583] Loss_D: 0.7196 Loss_G: 2.1319 D(x): 0.7332 D(G(z)): 0.2935 / 0.1403
[3/5][550/1583] Loss_D: 0.5697 Loss_G: 2.6649 D(x): 0.8816 D(G(z)): 0.3210 / 0.0917
[3/5][600/1583] Loss_D: 0.7779 Loss_G: 1.2727 D(x): 0.5540 D(G(z)): 0.0855 / 0.3412
[3/5][650/1583] Loss_D: 0.4090 Loss_G: 2.6893 D(x): 0.8334 D(G(z)): 0.1835 / 0.0855
[3/5][700/1583] Loss_D: 0.8108 Loss_G: 3.8991 D(x): 0.9241 D(G(z)): 0.4716 / 0.0281
[3/5][750/1583] Loss_D: 0.9907 Loss_G: 4.7885 D(x): 0.9111 D(G(z)): 0.5402 / 0.0123
[3/5][800/1583] Loss_D: 0.4725 Loss_G: 2.3347 D(x): 0.7577 D(G(z)): 0.1400 / 0.1222
[3/5][850/1583] Loss_D: 1.5580 Loss_G: 4.9586 D(x): 0.8954 D(G(z)): 0.7085 / 0.0132
[3/5][900/1583] Loss_D: 0.5785 Loss_G: 1.6395 D(x): 0.6581 D(G(z)): 0.1003 / 0.2411
[3/5][950/1583] Loss_D: 0.6592 Loss_G: 1.0890 D(x): 0.5893 D(G(z)): 0.0451 / 0.3809
[3/5][1000/1583] Loss_D: 0.7280 Loss_G: 3.5368 D(x): 0.8898 D(G(z)): 0.4176 / 0.0409
[3/5][1050/1583] Loss_D: 0.7088 Loss_G: 3.4301 D(x): 0.8558 D(G(z)): 0.3845 / 0.0457
[3/5][1100/1583] Loss_D: 0.5651 Loss_G: 2.1150 D(x): 0.7602 D(G(z)): 0.2127 / 0.1532
[3/5][1150/1583] Loss_D: 0.5412 Loss_G: 1.7790 D(x): 0.6602 D(G(z)): 0.0801 / 0.2088
[3/5][1200/1583] Loss_D: 1.2277 Loss_G: 1.1464 D(x): 0.4864 D(G(z)): 0.2915 / 0.3665
[3/5][1250/1583] Loss_D: 0.7148 Loss_G: 1.3957 D(x): 0.5948 D(G(z)): 0.1076 / 0.2876
[3/5][1300/1583] Loss_D: 1.0675 Loss_G: 1.3018 D(x): 0.4056 D(G(z)): 0.0310 / 0.3355
[3/5][1350/1583] Loss_D: 0.8064 Loss_G: 0.7482 D(x): 0.5846 D(G(z)): 0.1453 / 0.5147
[3/5][1400/1583] Loss_D: 0.6032 Loss_G: 3.0601 D(x): 0.8474 D(G(z)): 0.3189 / 0.0590
[3/5][1450/1583] Loss_D: 0.5329 Loss_G: 2.8172 D(x): 0.8234 D(G(z)): 0.2567 / 0.0795
[3/5][1500/1583] Loss_D: 0.9292 Loss_G: 3.5544 D(x): 0.8686 D(G(z)): 0.4887 / 0.0410
[3/5][1550/1583] Loss_D: 0.5929 Loss_G: 2.9118 D(x): 0.8614 D(G(z)): 0.3239 / 0.0702
[4/5][0/1583] Loss_D: 0.5564 Loss_G: 2.7516 D(x): 0.8716 D(G(z)): 0.3145 / 0.0799
[4/5][50/1583] Loss_D: 1.0485 Loss_G: 0.6751 D(x): 0.4332 D(G(z)): 0.0675 / 0.5568
[4/5][100/1583] Loss_D: 0.6753 Loss_G: 1.4046 D(x): 0.6028 D(G(z)): 0.0882 / 0.2901
[4/5][150/1583] Loss_D: 0.5946 Loss_G: 1.7618 D(x): 0.6862 D(G(z)): 0.1488 / 0.2016
[4/5][200/1583] Loss_D: 0.4866 Loss_G: 2.2638 D(x): 0.7628 D(G(z)): 0.1633 / 0.1321
[4/5][250/1583] Loss_D: 0.7493 Loss_G: 1.0999 D(x): 0.5541 D(G(z)): 0.0659 / 0.3787
[4/5][300/1583] Loss_D: 1.0886 Loss_G: 4.6532 D(x): 0.9370 D(G(z)): 0.5811 / 0.0149
[4/5][350/1583] Loss_D: 0.6106 Loss_G: 1.9212 D(x): 0.6594 D(G(z)): 0.1322 / 0.1825
[4/5][400/1583] Loss_D: 0.5226 Loss_G: 2.9611 D(x): 0.8178 D(G(z)): 0.2378 / 0.0731
[4/5][450/1583] Loss_D: 1.0068 Loss_G: 1.3267 D(x): 0.4310 D(G(z)): 0.0375 / 0.3179
[4/5][500/1583] Loss_D: 3.1088 Loss_G: 0.1269 D(x): 0.0706 D(G(z)): 0.0061 / 0.8897
[4/5][550/1583] Loss_D: 1.7889 Loss_G: 0.4800 D(x): 0.2175 D(G(z)): 0.0143 / 0.6479
[4/5][600/1583] Loss_D: 0.6732 Loss_G: 3.5685 D(x): 0.8775 D(G(z)): 0.3879 / 0.0362
[4/5][650/1583] Loss_D: 0.5169 Loss_G: 2.1943 D(x): 0.7222 D(G(z)): 0.1349 / 0.1416
[4/5][700/1583] Loss_D: 0.4567 Loss_G: 2.4442 D(x): 0.7666 D(G(z)): 0.1410 / 0.1204
[4/5][750/1583] Loss_D: 0.5972 Loss_G: 2.2992 D(x): 0.6286 D(G(z)): 0.0670 / 0.1283
[4/5][800/1583] Loss_D: 0.5461 Loss_G: 1.9777 D(x): 0.7013 D(G(z)): 0.1318 / 0.1795
[4/5][850/1583] Loss_D: 0.6317 Loss_G: 2.2345 D(x): 0.6962 D(G(z)): 0.1854 / 0.1385
[4/5][900/1583] Loss_D: 0.6034 Loss_G: 3.2300 D(x): 0.8781 D(G(z)): 0.3448 / 0.0517
[4/5][950/1583] Loss_D: 0.6371 Loss_G: 2.7755 D(x): 0.8595 D(G(z)): 0.3357 / 0.0826
[4/5][1000/1583] Loss_D: 0.6077 Loss_G: 3.3958 D(x): 0.9026 D(G(z)): 0.3604 / 0.0458
[4/5][1050/1583] Loss_D: 0.5057 Loss_G: 3.2545 D(x): 0.8705 D(G(z)): 0.2691 / 0.0546
[4/5][1100/1583] Loss_D: 0.4552 Loss_G: 2.0632 D(x): 0.7887 D(G(z)): 0.1704 / 0.1524
[4/5][1150/1583] Loss_D: 0.9933 Loss_G: 1.0264 D(x): 0.4507 D(G(z)): 0.0636 / 0.4182
[4/5][1200/1583] Loss_D: 0.5037 Loss_G: 1.9940 D(x): 0.6967 D(G(z)): 0.0959 / 0.1698
[4/5][1250/1583] Loss_D: 0.4760 Loss_G: 2.5973 D(x): 0.8192 D(G(z)): 0.2164 / 0.0945
[4/5][1300/1583] Loss_D: 1.0137 Loss_G: 3.8782 D(x): 0.9330 D(G(z)): 0.5405 / 0.0309
[4/5][1350/1583] Loss_D: 0.9084 Loss_G: 3.1406 D(x): 0.7540 D(G(z)): 0.3980 / 0.0648
[4/5][1400/1583] Loss_D: 0.6724 Loss_G: 4.1269 D(x): 0.9536 D(G(z)): 0.4234 / 0.0236
[4/5][1450/1583] Loss_D: 0.6452 Loss_G: 3.5163 D(x): 0.8730 D(G(z)): 0.3555 / 0.0412
[4/5][1500/1583] Loss_D: 0.8843 Loss_G: 1.4950 D(x): 0.5314 D(G(z)): 0.1035 / 0.2835
[4/5][1550/1583] Loss_D: 2.3345 Loss_G: 1.0675 D(x): 0.1448 D(G(z)): 0.0228 / 0.4177
结果
最后,让我们看看我们的表现如何。在这里,我们将看看三个不同的结果。首先,我们将看到 D 和 G 的损失在训练过程中如何变化。其次,我们将可视化每个时期的固定噪声批次上 G 的输出。第三,我们将查看来自 G 的一批真实数据和一批假数据。
损失与训练迭代
下面是 D & G 的损失与训练迭代的关系图。
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
G 进展的可视化
记住我们如何在每个训练周期后将生成器的输出保存在固定噪声批次上。现在,我们可以用动画可视化 G 的训练过程。按播放按钮开始动画。
fig = plt.figure(figsize=(8,8))
plt.axis("off")
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
HTML(ani.to_jshtml())
真实图像与假图像
最后,让我们并排看一些真实图像和假图像。
# Grab a batch of real images from the dataloader
real_batch = next(iter(dataloader))
# Plot the real images
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu(),(1,2,0)))
# Plot the fake images from the last epoch
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()