第G1周:生成对抗网络(GAN)入门

一、前期准备

1.定义超参数

代码知识点
这部分代码导入了所需的库和模块,包括
argparse用于解析命令行参数,
os用于处理文件路径,
numpy用于数值计算,
torchvision.transforms用于图像变换,
torchvision.utils中的save_image用于保存生成的图像,
torch.utils.data中的DataLoader用于加载数据,
torchvision中的datasets用于加载数据集,
torch.autograd中的Variable用于自动求导,
torch.nn用于构建神经网络模型,以及torch用于张量操作。

这些参数会影响生成对抗网络(GAN)的训练过程,具体如下:

n_epochs:这个参数决定了模型训练的总轮数。轮数越多,模型有更多机会学习数据中的模式,但也可能导致过拟合。

batch_size:批次大小影响模型每次更新时使用的数据量。较小的批次可能导致训练过程波动较大,但可能有助于模型逃离局部最小值;较大的批次则可能使训练更稳定,但需要更多的内存空间。

lr:学习率控制着模型权重更新的步长。学习率过大可能导致模型在最优解附近震荡甚至发散;学习率过小则可能导致模型收敛速度缓慢或陷入局部最小值。

b1和b2:这两个参数是Adam优化器的一部分,分别控制一阶矩(梯度的指数移动平均)和二阶矩(梯度平方的指数移动平均)的指数衰减率。它们影响模型更新的稳定性和收敛速度。

n_cpu:这个参数指定了用于数据加载的CPU数量,可以影响数据预处理和加载的速度,进而影响训练的效率。

latent_dim:随机向量的维度,它影响生成器生成图像的多样性和质量。维度过低可能导致生成图像缺乏多样性,而维度过高可能导致模型难以训练。

img_size:图像的大小直接影响模型的感受野和所需计算资源。图像尺寸越大,模型可能需要更多的计算资源和更长的训练时间。

channels:图像的通道数,对于彩色图像通常是3(RGB),对于灰度图像是1。通道数影响模型处理的信息量。

sample_interval:保存生成图像的间隔,这个参数决定了我们在训练过程中多久保存一次生成的图像,用于监控生成图像的质量。

cuda:是否使用GPU进行计算,使用GPU可以显著加速模型的训练过程,因为GPU在并行处理大量计算时更为高效。

import argparse
import os 
import numpy as np 
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch

os.makedirs("./images/",exist_ok=True)
os.makedirs("./save/",exist_ok=True)
os.makedirs("./datasets/mnist",exist_ok=True)

n_epochs=50
batch_size=64
lr=0.0002
b1=0.5
b2=0.999
n_cpu=2
latent_dim=100
img_size=28
channels=1
sample_interval=500
img_shape =(channels,img_size,img_size)
img_area = np.prod(img_shape)
cuda = True if torch.cuda.is_available() else False
print(cuda)

输出
True

2.下载数据

mnist = datasets.MNIST(
    root = './datasets/', train=True, download=True, transform=transforms.Compose(
        [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
)

3.配置数据

dataloader =DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,
)

二、定义模型

1.定义辨别器

代码知识点
这段代码定义了一个名为Discriminator的类,它继承自nn.Module。这个类是一个判别器模型,用于判断输入图像是否为真实图像。下面是对代码中每一行的详细解释:

  1. class Discriminator(nn.Module)::定义一个名为Discriminator的类,它继承自nn.Modulenn.Module是PyTorch中的一个基类,用于构建神经网络模型。

  2. def __init__(self)::定义类的构造函数,用于初始化模型的参数和层。

  3. super(Discriminator,self).__init__():调用父类nn.Module的构造函数,以确保正确地初始化模型。

  4. self.model = nn.Sequential(:创建一个nn.Sequential对象,它是一个容器,用于按顺序堆叠多个神经网络层。

  5. nn.Linear(img_area,512),:添加一个线性层,输入大小为img_area(图像区域的像素数),输出大小为512。这个层用于将输入图像展平并映射到一个新的特征空间。

  6. nn.LeakyReLU(0.2,inplace=True),:添加一个Leaky ReLU激活函数,其负斜率为0.2。inplace=True表示在原始数据上进行操作,以节省内存。

  7. nn.Linear(512,256),:添加一个线性层,输入大小为512,输出大小为256。这个层用于进一步将特征映射到更小的特征空间。

  8. nn.LeakyReLU(0.2,inplace=True),:再次添加一个Leaky ReLU激活函数,与之前的层相同。

  9. nn.Linear(256,1),:添加一个线性层,输入大小为256,输出大小为1。这个层用于将特征映射到一个标量值,用于表示输入图像的真实性。

  10. nn.Sigmoid(),:添加一个Sigmoid激活函数,将输出值限制在0到1之间。这可以解释为输入图像为真实图像的概率。

  11. ):结束nn.Sequential对象的创建。

  12. def forward(self, img)::定义模型的前向传播函数,用于计算输入图像的输出。

  13. img_flat = img.view(img.size(0),-1):将输入图像img展平为一个一维向量。img.size(0)表示批量大小,-1表示自动计算剩余维度的大小。

  14. validity = self.model(img_flat):将展平后的图像传递给之前定义的nn.Sequential模型,得到一个表示图像真实性的标量值。

  15. return validity:返回计算得到的图像真实性值。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area,512),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(512,256),
            nn.LeakyReLU(0.2,inplace=True),
            nn.Linear(256,1),
            nn.Sigmoid(),
        )
    def forward(self, img):
        img_flat = img.view(img.size(0),-1)
        validity = self.model(img_flat)
        return validity

2.定义生成器

代码知识点

  1. class Generator(nn.Module)::定义一个名为Generator的类,继承自PyTorch的nn.Module模块。
  2. def __init__ (self)::定义类的初始化方法。
  3. super(Generator,self).__init__():调用父类的初始化方法。
  4. def block(in_feat, out_feat, normalize=True)::定义一个名为block的内部函数,用于构建生成器网络的每个块。
  5. layers = [nn.Linear(in_feat, out_feat)]:创建一个线性层,输入特征数为in_feat,输出特征数为out_feat。
  6. if normalize::判断是否需要进行批量归一化。
  7. layers.append(nn.BatchNorm1d(out_feat, 0.8)):添加批量归一化层,输出特征数为out_feat,动量参数为0.8。
  8. layers.append(nn.LeakyReLU(0.2,inplace=True)):添加Leaky ReLU激活函数,负斜率为0.2,inplace参数设置为True表示直接修改输入数据。
  9. return layers:返回构建好的层列表。
  10. self.model = nn.Sequential(:创建一个顺序模型,用于堆叠生成器的各个层。
  11. *block(latent_dim,128,normalize=False),:调用block函数构建第一个块,输入特征数为latent_dim,输出特征数为128,不进行批量归一化。
  12. *block(128,256),:调用block函数构建第二个块,输入特征数为128,输出特征数为256。
  13. *block(256,512),:调用block函数构建第三个块,输入特征数为256,输出特征数为512。
  14. *block(512,1024),:调用block函数构建第四个块,输入特征数为512,输出特征数为1024。
  15. nn.Linear(1024,img_area),:添加一个线性层,输入特征数为1024,输出特征数为img_area。
  16. nn.Tanh():添加一个Tanh激活函数,将输出值映射到-1到1之间。
  17. ):结束顺序模型的定义。
  18. def forward(self,z)::定义前向传播方法,输入为z。
  19. imgs = self.model(z):将输入z传入模型,得到输出imgs。
  20. imgs = imgs.view(imgs.size(0),*img_shape):将输出imgs的形状调整为(batch_size,
    *img_shape)。
  21. return imgs:返回调整形状后的imgs。
class Generator(nn.Module):
    def __init__ (self):
        super(Generator,self).__init__()
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))
            layers.append(nn.LeakyReLU(0.2,inplace=True))
            return layers
## prod():返回给定轴上的数组元素的乘积:1*28*28=784
        self.model = nn.Sequential(
            *block(latent_dim,128,normalize=False),
            *block(128,256),
            *block(256,512),
            *block(512,1024),
            nn.Linear(1024,img_area),
            nn.Tanh()
        )
## view():相当于numpy中的reshape,重新定义矩阵的形状
    def forward(self,z):
            imgs = self.model(z)
            imgs = imgs.view(imgs.size(0),*img_shape)
            return imgs

三、训练模型

1.创建实例

代码知识点

  1. generator = Generator():创建一个名为generator的生成器对象,该对象是Generator类的实
    例。
  2. discriminator = Discriminator():创建一个名为discriminator的判别器对象,该对象是Discriminator类的实例。
  3. criterion = torch.nn.BCELoss():创建一个名为criterion的损失函数对象,该对象是二元交叉熵损失(Binary Cross
    Entropy Loss)的实例。
  4. optimizer_G = torch.optim.Adam(generator.parameters(),lr=lr, betas=(b1,b2)):创建一个名为optimizer_G的优化器对象,用于优化生成器网络的参数。该对象是Adam优化器的实例,学习率为lr,beta1和beta2分别为b1和b2。
  5. optimizer_D = torch.optim.Adam(discriminator.parameters(),lr=lr, betas=(b1,b2)):创建一个名为optimizer_D的优化器对象,用于优化判别器网络的参数。该对象是Adam优化器的实例,学习率为lr,beta1和beta2分别为b1和b2。
  6. if torch.cuda.is_available()::判断当前设备是否支持CUDA加速。
  7. generator = generator.cuda():如果支持CUDA加速,将生成器对象转移到GPU上进行计算。
  8. discriminator = discriminator.cuda():如果支持CUDA加速,将判别器对象转移到GPU上进行计算。
  9. criterion = criterion.cuda():如果支持CUDA加速,将损失函数对象转移到GPU上进行计算。
generator = Generator()
discriminator = Discriminator()

criterion = torch.nn.BCELoss()

optimizer_G = torch.optim.Adam(generator.parameters(),lr=lr, betas=(b1,b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),lr=lr, betas=(b1,b2))

if torch.cuda.is_available():
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion = criterion.cuda()

2.训练模型

代码知识点

  1. for epoch in range(n_epochs)::开始一个循环,循环次数为n_epochs,每次循环代表一个训练周期。
  2. for i,(imgs, _) in enumerate(dataloader)::在每个训练周期内,遍历数据加载器(dataloader),获取图像数据和标签。
  3. imgs = imgs.view(imgs.size(0), -1):将图像数据调整为二维张量,其中第一维是批次大小,第二维是图像的展平表示。
  4. real_img = Variable(imgs).cuda():将真实的图像数据转换为变量并移动到GPU上进行计算。
  5. real_label = Variable(torch.ones(imgs.size(0),1)).cuda():创建一个全为1的标签向量,用于判别器的真实图像损失计算。
  6. fake_label = Variable(torch.zeros(imgs.size(0),1)).cuda():创建一个全为0的标签向量,用于判别器的生成图像损失计算。
  7. real_out = discriminator(real_img):将真实图像输入判别器,得到判别器对真实图像的输出。
  8. loss_real_D = criterion(real_out,real_label):计算判别器对真实图像的损失。
  9. real_scores = real_out:将判别器对真实图像的输出保存为真实分数。
  10. z = Variable(torch.randn(imgs.size(0),latent_dim)).cuda():生成随机噪声向量z。
  11. fake_img = generator(z).detach():将噪声向量z输入生成器,得到生成的图像,并将其从计算图中分离。
  12. fake_out = discriminator(fake_img):将生成的图像输入判别器,得到判别器对生成图像的输出。
  13. loss_fake_D = criterion(fake_out,fake_label):计算判别器对生成图像的损失。
  14. fake_scores = fake_out:将判别器对生成图像的输出保存为假分数。
  15. loss_D = loss_real_D + loss_fake_D:计算判别器的总损失,即真实图像损失和生成图像损失之和。
  16. optimizer_D.zero_grad():将判别器的梯度清零。
  17. loss_D.backward():计算判别器的损失关于参数的梯度。
  18. optimizer_D.step():更新判别器的参数。
  19. z = Variable(torch.randn(imgs.size(0),latent_dim)).cuda():再次生成随机噪声向量z。
  20. fake_img = generator(z):将噪声向量z输入生成器,得到生成的图像。
  21. output = discriminator(fake_img):将生成的图像输入判别器,得到判别器对生成图像的输出。
  22. loss_G = criterion(output,real_label):计算生成器的损失。
  23. optimizer_G.zero_grad():将生成器的梯度清零。
  24. loss_G.backward():计算生成器的损失关于参数的梯度。
  25. optimizer_G.step():更新生成器的参数。
  26. if (i+1) % 300 == 0::每隔300个批次打印一次训练信息。
  27. print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]" % (epoch,n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean())):打印当前训练周期、批次、判别器损失、生成器损失、判别器对真实图像的平均分数和判别器对生成图像的平均分数。
  28. batches_done = epoch * len(dataloader) + i:计算已完成的批次数。
  29. if batches_done % sample_interval == 0::每隔一定间隔保存生成的图像。
  30. save_image(fake_img.data[:25],"./images/%d.png" % batches_done, nrow=5, normalize=True):保存前25个生成的图像到指定路径,并进行归一化处理。
for epoch in range(n_epochs):
    for i,(imgs, _) in enumerate(dataloader):
        imgs = imgs.view(imgs.size(0), -1)
        real_img = Variable(imgs).cuda()
        real_label = Variable(torch.ones(imgs.size(0),1)).cuda()
        fake_label = Variable(torch.zeros(imgs.size(0),1)).cuda()

        real_out = discriminator(real_img)
        loss_real_D = criterion(real_out,real_label)
        real_scores = real_out
        z = Variable(torch.randn(imgs.size(0),latent_dim)).cuda()
        fake_img = generator(z).detach()
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out,fake_label)
        fake_scores = fake_out

        loss_D = loss_real_D + loss_fake_D
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        z = Variable(torch.randn(imgs.size(0),latent_dim)).cuda()
        fake_img = generator(z)
        output = discriminator(fake_img)
        
        loss_G = criterion(output,real_label)
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()


        if (i+1) % 300 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch,n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean())
            )

        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25],"./images/%d.png" % batches_done, nrow=5, normalize=True)

输出
[Epoch 0/50] [Batch 299/938] [D loss: 1.486548] [G loss: 1.263540] [D real: 0.796954] [D fake: 0.705474]
[Epoch 0/50] [Batch 599/938] [D loss: 1.022943] [G loss: 0.799070] [D real: 0.472179] [D fake: 0.185987]
[Epoch 0/50] [Batch 899/938] [D loss: 1.116177] [G loss: 1.078084] [D real: 0.567511] [D fake: 0.387150]
[Epoch 1/50] [Batch 299/938] [D loss: 1.492023] [G loss: 2.803232] [D real: 0.912035] [D fake: 0.741440]
[Epoch 1/50] [Batch 599/938] [D loss: 1.088952] [G loss: 1.283060] [D real: 0.638842] [D fake: 0.424129]
[Epoch 1/50] [Batch 899/938] [D loss: 1.045345] [G loss: 1.928071] [D real: 0.777452] [D fake: 0.519182]
[Epoch 2/50] [Batch 299/938] [D loss: 0.924773] [G loss: 1.558642] [D real: 0.691493] [D fake: 0.375394]
[Epoch 2/50] [Batch 599/938] [D loss: 0.859401] [G loss: 1.168849] [D real: 0.610856] [D fake: 0.238899]
[Epoch 2/50] [Batch 899/938] [D loss: 0.994440] [G loss: 0.886017] [D real: 0.506867] [D fake: 0.113889]
[Epoch 3/50] [Batch 299/938] [D loss: 1.111854] [G loss: 0.959417] [D real: 0.470583] [D fake: 0.124310]
[Epoch 3/50] [Batch 599/938] [D loss: 0.795403] [G loss: 1.288844] [D real: 0.660531] [D fake: 0.231165]
[Epoch 3/50] [Batch 899/938] [D loss: 0.916148] [G loss: 2.614962] [D real: 0.825633] [D fake: 0.481356]
[Epoch 4/50] [Batch 299/938] [D loss: 0.884143] [G loss: 2.002452] [D real: 0.764259] [D fake: 0.402012]
[Epoch 4/50] [Batch 599/938] [D loss: 0.741833] [G loss: 2.512850] [D real: 0.797984] [D fake: 0.330773]
[Epoch 4/50] [Batch 899/938] [D loss: 0.778115] [G loss: 1.621227] [D real: 0.742208] [D fake: 0.324576]
[Epoch 5/50] [Batch 299/938] [D loss: 0.764483] [G loss: 1.628026] [D real: 0.764330] [D fake: 0.341625]
[Epoch 5/50] [Batch 599/938] [D loss: 0.863362] [G loss: 2.213945] [D real: 0.813163] [D fake: 0.430946]
[Epoch 5/50] [Batch 899/938] [D loss: 0.816530] [G loss: 1.191403] [D real: 0.599072] [D fake: 0.124816]
[Epoch 6/50] [Batch 299/938] [D loss: 0.893659] [G loss: 2.667447] [D real: 0.821634] [D fake: 0.457403]
[Epoch 6/50] [Batch 599/938] [D loss: 0.937223] [G loss: 0.625343] [D real: 0.553249] [D fake: 0.183966]
[Epoch 6/50] [Batch 899/938] [D loss: 0.858882] [G loss: 1.268875] [D real: 0.662991] [D fake: 0.270400]
[Epoch 7/50] [Batch 299/938] [D loss: 0.920097] [G loss: 1.110757] [D real: 0.662119] [D fake: 0.310955]
[Epoch 7/50] [Batch 599/938] [D loss: 0.848626] [G loss: 1.869307] [D real: 0.794307] [D fake: 0.407518]
[Epoch 7/50] [Batch 899/938] [D loss: 0.894683] [G loss: 1.515777] [D real: 0.668894] [D fake: 0.315810]
[Epoch 8/50] [Batch 299/938] [D loss: 0.945508] [G loss: 1.450796] [D real: 0.684737] [D fake: 0.359553]
[Epoch 8/50] [Batch 599/938] [D loss: 0.893079] [G loss: 1.032745] [D real: 0.595712] [D fake: 0.202011]
[Epoch 8/50] [Batch 899/938] [D loss: 1.158379] [G loss: 0.836311] [D real: 0.476757] [D fake: 0.137825]
[Epoch 9/50] [Batch 299/938] [D loss: 1.194173] [G loss: 1.441912] [D real: 0.715668] [D fake: 0.492410]
[Epoch 9/50] [Batch 599/938] [D loss: 0.977941] [G loss: 1.354820] [D real: 0.690165] [D fake: 0.390519]
[Epoch 9/50] [Batch 899/938] [D loss: 1.047443] [G loss: 0.890079] [D real: 0.501754] [D fake: 0.143151]
[Epoch 10/50] [Batch 299/938] [D loss: 0.928828] [G loss: 1.659332] [D real: 0.735775] [D fake: 0.406559]
[Epoch 10/50] [Batch 599/938] [D loss: 0.805291] [G loss: 1.204942] [D real: 0.677094] [D fake: 0.243433]
[Epoch 10/50] [Batch 899/938] [D loss: 1.082386] [G loss: 0.736025] [D real: 0.518290] [D fake: 0.203062]
[Epoch 11/50] [Batch 299/938] [D loss: 1.207286] [G loss: 2.173785] [D real: 0.823033] [D fake: 0.567325]
[Epoch 11/50] [Batch 599/938] [D loss: 0.864652] [G loss: 1.226438] [D real: 0.670658] [D fake: 0.292929]
[Epoch 11/50] [Batch 899/938] [D loss: 1.067761] [G loss: 0.816991] [D real: 0.557894] [D fake: 0.265028]
[Epoch 12/50] [Batch 299/938] [D loss: 1.094701] [G loss: 0.828730] [D real: 0.478709] [D fake: 0.152698]
[Epoch 12/50] [Batch 599/938] [D loss: 0.923310] [G loss: 1.575744] [D real: 0.673793] [D fake: 0.341787]
[Epoch 12/50] [Batch 899/938] [D loss: 0.882771] [G loss: 1.570408] [D real: 0.721132] [D fake: 0.348645]
[Epoch 13/50] [Batch 299/938] [D loss: 1.213910] [G loss: 1.739210] [D real: 0.816769] [D fake: 0.581116]
[Epoch 13/50] [Batch 599/938] [D loss: 0.986962] [G loss: 2.176203] [D real: 0.801824] [D fake: 0.485661]
[Epoch 13/50] [Batch 899/938] [D loss: 1.123384] [G loss: 1.887434] [D real: 0.788300] [D fake: 0.526488]
[Epoch 14/50] [Batch 299/938] [D loss: 0.993780] [G loss: 1.074014] [D real: 0.621304] [D fake: 0.302661]
[Epoch 14/50] [Batch 599/938] [D loss: 1.139293] [G loss: 0.902772] [D real: 0.486842] [D fake: 0.136740]
[Epoch 14/50] [Batch 899/938] [D loss: 1.039376] [G loss: 0.765984] [D real: 0.579033] [D fake: 0.273950]
[Epoch 15/50] [Batch 299/938] [D loss: 0.957737] [G loss: 1.730967] [D real: 0.735324] [D fake: 0.417273]
[Epoch 15/50] [Batch 599/938] [D loss: 1.133914] [G loss: 1.851659] [D real: 0.768372] [D fake: 0.543797]
[Epoch 15/50] [Batch 899/938] [D loss: 0.953224] [G loss: 1.377125] [D real: 0.664915] [D fake: 0.350922]
[Epoch 16/50] [Batch 299/938] [D loss: 0.970558] [G loss: 1.432826] [D real: 0.672650] [D fake: 0.376319]
[Epoch 16/50] [Batch 599/938] [D loss: 0.917780] [G loss: 1.423335] [D real: 0.716623] [D fake: 0.372080]
[Epoch 16/50] [Batch 899/938] [D loss: 1.034982] [G loss: 1.312837] [D real: 0.676499] [D fake: 0.388286]
[Epoch 17/50] [Batch 299/938] [D loss: 1.283185] [G loss: 2.052758] [D real: 0.843837] [D fake: 0.627986]
[Epoch 17/50] [Batch 599/938] [D loss: 1.014387] [G loss: 1.197765] [D real: 0.630091] [D fake: 0.338972]
[Epoch 17/50] [Batch 899/938] [D loss: 0.953568] [G loss: 1.231606] [D real: 0.657918] [D fake: 0.351649]
[Epoch 18/50] [Batch 299/938] [D loss: 1.032209] [G loss: 1.185335] [D real: 0.588554] [D fake: 0.312889]
[Epoch 18/50] [Batch 599/938] [D loss: 0.938251] [G loss: 1.347066] [D real: 0.614758] [D fake: 0.287012]
[Epoch 18/50] [Batch 899/938] [D loss: 1.098345] [G loss: 1.045380] [D real: 0.670897] [D fake: 0.438555]
[Epoch 19/50] [Batch 299/938] [D loss: 1.041932] [G loss: 1.125989] [D real: 0.623168] [D fake: 0.349922]
[Epoch 19/50] [Batch 599/938] [D loss: 1.049016] [G loss: 0.997303] [D real: 0.560449] [D fake: 0.289190]
[Epoch 19/50] [Batch 899/938] [D loss: 1.109702] [G loss: 0.830619] [D real: 0.554481] [D fake: 0.302520]
[Epoch 20/50] [Batch 299/938] [D loss: 1.091989] [G loss: 1.078110] [D real: 0.579322] [D fake: 0.306465]
[Epoch 20/50] [Batch 599/938] [D loss: 1.075116] [G loss: 1.026978] [D real: 0.646824] [D fake: 0.379286]
[Epoch 20/50] [Batch 899/938] [D loss: 1.012776] [G loss: 1.252582] [D real: 0.650999] [D fake: 0.355130]
[Epoch 21/50] [Batch 299/938] [D loss: 1.019206] [G loss: 1.080652] [D real: 0.571173] [D fake: 0.282273]
[Epoch 21/50] [Batch 599/938] [D loss: 1.092633] [G loss: 0.927598] [D real: 0.527832] [D fake: 0.257436]
[Epoch 21/50] [Batch 899/938] [D loss: 1.221342] [G loss: 1.074707] [D real: 0.572400] [D fake: 0.380784]
[Epoch 22/50] [Batch 299/938] [D loss: 1.065374] [G loss: 0.797963] [D real: 0.484218] [D fake: 0.195965]
[Epoch 22/50] [Batch 599/938] [D loss: 1.105837] [G loss: 0.691518] [D real: 0.514260] [D fake: 0.276405]
[Epoch 22/50] [Batch 899/938] [D loss: 0.878742] [G loss: 1.166016] [D real: 0.724698] [D fake: 0.377328]
[Epoch 23/50] [Batch 299/938] [D loss: 1.011820] [G loss: 1.244210] [D real: 0.713064] [D fake: 0.429894]
[Epoch 23/50] [Batch 599/938] [D loss: 1.073738] [G loss: 1.664922] [D real: 0.719866] [D fake: 0.461699]
[Epoch 23/50] [Batch 899/938] [D loss: 1.009477] [G loss: 1.276478] [D real: 0.616435] [D fake: 0.336026]
[Epoch 24/50] [Batch 299/938] [D loss: 1.020366] [G loss: 1.251225] [D real: 0.648893] [D fake: 0.388956]
[Epoch 24/50] [Batch 599/938] [D loss: 1.101402] [G loss: 2.135847] [D real: 0.786404] [D fake: 0.536431]
[Epoch 24/50] [Batch 899/938] [D loss: 1.202906] [G loss: 0.715893] [D real: 0.452851] [D fake: 0.163171]
[Epoch 25/50] [Batch 299/938] [D loss: 1.140191] [G loss: 0.871201] [D real: 0.527046] [D fake: 0.297866]
[Epoch 25/50] [Batch 599/938] [D loss: 1.117683] [G loss: 1.204782] [D real: 0.623783] [D fake: 0.375783]
[Epoch 25/50] [Batch 899/938] [D loss: 1.122934] [G loss: 1.279258] [D real: 0.694599] [D fake: 0.453740]
[Epoch 26/50] [Batch 299/938] [D loss: 1.059794] [G loss: 0.850840] [D real: 0.512065] [D fake: 0.202767]
[Epoch 26/50] [Batch 599/938] [D loss: 1.072658] [G loss: 1.585017] [D real: 0.740615] [D fake: 0.476374]
[Epoch 26/50] [Batch 899/938] [D loss: 1.177186] [G loss: 0.832109] [D real: 0.472754] [D fake: 0.196632]
[Epoch 27/50] [Batch 299/938] [D loss: 0.986746] [G loss: 1.037914] [D real: 0.552875] [D fake: 0.219123]
[Epoch 27/50] [Batch 599/938] [D loss: 1.001620] [G loss: 1.654807] [D real: 0.745383] [D fake: 0.430741]
[Epoch 27/50] [Batch 899/938] [D loss: 1.007908] [G loss: 0.997429] [D real: 0.587789] [D fake: 0.299459]
[Epoch 28/50] [Batch 299/938] [D loss: 1.002784] [G loss: 1.241717] [D real: 0.709475] [D fake: 0.420782]
[Epoch 28/50] [Batch 599/938] [D loss: 1.058915] [G loss: 1.121608] [D real: 0.660302] [D fake: 0.398084]
[Epoch 28/50] [Batch 899/938] [D loss: 0.949169] [G loss: 0.922655] [D real: 0.599296] [D fake: 0.259329]
[Epoch 29/50] [Batch 299/938] [D loss: 1.099054] [G loss: 0.816260] [D real: 0.549399] [D fake: 0.305472]
[Epoch 29/50] [Batch 599/938] [D loss: 1.008659] [G loss: 1.256130] [D real: 0.737606] [D fake: 0.429921]
[Epoch 29/50] [Batch 899/938] [D loss: 1.039975] [G loss: 1.456424] [D real: 0.730297] [D fake: 0.425789]
[Epoch 30/50] [Batch 299/938] [D loss: 1.063697] [G loss: 1.032279] [D real: 0.618258] [D fake: 0.364028]
[Epoch 30/50] [Batch 599/938] [D loss: 1.114642] [G loss: 1.696752] [D real: 0.735125] [D fake: 0.477111]
[Epoch 30/50] [Batch 899/938] [D loss: 1.077305] [G loss: 1.105972] [D real: 0.530655] [D fake: 0.231420]
[Epoch 31/50] [Batch 299/938] [D loss: 1.011994] [G loss: 1.358363] [D real: 0.667097] [D fake: 0.406712]
[Epoch 31/50] [Batch 599/938] [D loss: 0.991479] [G loss: 1.544970] [D real: 0.680855] [D fake: 0.399283]
[Epoch 31/50] [Batch 899/938] [D loss: 1.009350] [G loss: 1.655224] [D real: 0.741437] [D fake: 0.454083]
[Epoch 32/50] [Batch 299/938] [D loss: 1.009774] [G loss: 1.154610] [D real: 0.697059] [D fake: 0.417741]
[Epoch 32/50] [Batch 599/938] [D loss: 1.131516] [G loss: 1.604385] [D real: 0.799639] [D fake: 0.521479]
[Epoch 32/50] [Batch 899/938] [D loss: 0.971643] [G loss: 1.751588] [D real: 0.769394] [D fake: 0.447989]
[Epoch 33/50] [Batch 299/938] [D loss: 0.937302] [G loss: 1.137668] [D real: 0.667549] [D fake: 0.352494]
[Epoch 33/50] [Batch 599/938] [D loss: 1.142665] [G loss: 1.605705] [D real: 0.736860] [D fake: 0.494053]
[Epoch 33/50] [Batch 899/938] [D loss: 0.975630] [G loss: 1.081744] [D real: 0.617061] [D fake: 0.311667]
[Epoch 34/50] [Batch 299/938] [D loss: 1.004504] [G loss: 1.199623] [D real: 0.584345] [D fake: 0.250685]
[Epoch 34/50] [Batch 599/938] [D loss: 1.111552] [G loss: 1.399492] [D real: 0.683633] [D fake: 0.437264]
[Epoch 34/50] [Batch 899/938] [D loss: 1.057444] [G loss: 1.816881] [D real: 0.718745] [D fake: 0.452988]
[Epoch 35/50] [Batch 299/938] [D loss: 1.072661] [G loss: 1.064296] [D real: 0.542904] [D fake: 0.252511]
[Epoch 35/50] [Batch 599/938] [D loss: 1.006217] [G loss: 1.576730] [D real: 0.748175] [D fake: 0.426752]
[Epoch 35/50] [Batch 899/938] [D loss: 1.006518] [G loss: 1.166273] [D real: 0.663842] [D fake: 0.369748]
[Epoch 36/50] [Batch 299/938] [D loss: 1.074173] [G loss: 0.836063] [D real: 0.564981] [D fake: 0.312169]
[Epoch 36/50] [Batch 599/938] [D loss: 1.115033] [G loss: 1.369953] [D real: 0.592239] [D fake: 0.329859]
[Epoch 36/50] [Batch 899/938] [D loss: 1.022781] [G loss: 1.326988] [D real: 0.578367] [D fake: 0.262427]
[Epoch 37/50] [Batch 299/938] [D loss: 0.978686] [G loss: 1.245194] [D real: 0.673857] [D fake: 0.346784]
[Epoch 37/50] [Batch 599/938] [D loss: 1.157297] [G loss: 1.436168] [D real: 0.660286] [D fake: 0.434538]
[Epoch 37/50] [Batch 899/938] [D loss: 0.866226] [G loss: 1.513450] [D real: 0.741742] [D fake: 0.375720]
[Epoch 38/50] [Batch 299/938] [D loss: 1.042357] [G loss: 1.409830] [D real: 0.776054] [D fake: 0.474026]
[Epoch 38/50] [Batch 599/938] [D loss: 0.934915] [G loss: 1.663973] [D real: 0.702456] [D fake: 0.368171]
[Epoch 38/50] [Batch 899/938] [D loss: 1.059174] [G loss: 1.293531] [D real: 0.617721] [D fake: 0.317426]
[Epoch 39/50] [Batch 299/938] [D loss: 0.964760] [G loss: 1.176818] [D real: 0.639019] [D fake: 0.328120]
[Epoch 39/50] [Batch 599/938] [D loss: 1.327827] [G loss: 2.243500] [D real: 0.822750] [D fake: 0.608486]
[Epoch 39/50] [Batch 899/938] [D loss: 0.964594] [G loss: 1.032648] [D real: 0.629228] [D fake: 0.300384]
[Epoch 40/50] [Batch 299/938] [D loss: 1.098591] [G loss: 0.879305] [D real: 0.577749] [D fake: 0.312382]
[Epoch 40/50] [Batch 599/938] [D loss: 0.914958] [G loss: 1.446056] [D real: 0.676897] [D fake: 0.326900]
[Epoch 40/50] [Batch 899/938] [D loss: 0.959015] [G loss: 1.476370] [D real: 0.706569] [D fake: 0.380073]
[Epoch 41/50] [Batch 299/938] [D loss: 1.369267] [G loss: 2.306901] [D real: 0.841119] [D fake: 0.601469]
[Epoch 41/50] [Batch 599/938] [D loss: 0.934483] [G loss: 1.628165] [D real: 0.624097] [D fake: 0.244898]
[Epoch 41/50] [Batch 899/938] [D loss: 1.014262] [G loss: 1.479493] [D real: 0.729656] [D fake: 0.417756]
[Epoch 42/50] [Batch 299/938] [D loss: 1.011384] [G loss: 1.144827] [D real: 0.611724] [D fake: 0.298554]
[Epoch 42/50] [Batch 599/938] [D loss: 0.968165] [G loss: 1.469469] [D real: 0.697256] [D fake: 0.376759]
[Epoch 42/50] [Batch 899/938] [D loss: 0.896047] [G loss: 1.262222] [D real: 0.693943] [D fake: 0.340994]
[Epoch 43/50] [Batch 299/938] [D loss: 1.098878] [G loss: 1.494534] [D real: 0.622190] [D fake: 0.352108]
[Epoch 43/50] [Batch 599/938] [D loss: 1.005295] [G loss: 1.588117] [D real: 0.714370] [D fake: 0.395605]
[Epoch 43/50] [Batch 899/938] [D loss: 0.841020] [G loss: 1.947297] [D real: 0.727799] [D fake: 0.331271]
[Epoch 44/50] [Batch 299/938] [D loss: 0.967317] [G loss: 1.357352] [D real: 0.663303] [D fake: 0.335372]
[Epoch 44/50] [Batch 599/938] [D loss: 0.907341] [G loss: 1.253705] [D real: 0.699086] [D fake: 0.349763]
[Epoch 44/50] [Batch 899/938] [D loss: 0.898501] [G loss: 1.435176] [D real: 0.682133] [D fake: 0.343082]
[Epoch 45/50] [Batch 299/938] [D loss: 0.954293] [G loss: 1.709867] [D real: 0.805845] [D fake: 0.466294]
[Epoch 45/50] [Batch 599/938] [D loss: 0.932705] [G loss: 1.283069] [D real: 0.612603] [D fake: 0.245902]
[Epoch 45/50] [Batch 899/938] [D loss: 1.107646] [G loss: 1.293812] [D real: 0.543919] [D fake: 0.228042]
[Epoch 46/50] [Batch 299/938] [D loss: 0.867950] [G loss: 1.428596] [D real: 0.673719] [D fake: 0.282370]
[Epoch 46/50] [Batch 599/938] [D loss: 1.078684] [G loss: 1.340557] [D real: 0.553742] [D fake: 0.248088]
[Epoch 46/50] [Batch 899/938] [D loss: 1.066378] [G loss: 1.557076] [D real: 0.690389] [D fake: 0.396920]
[Epoch 47/50] [Batch 299/938] [D loss: 0.871799] [G loss: 1.392173] [D real: 0.674179] [D fake: 0.292265]
[Epoch 47/50] [Batch 599/938] [D loss: 0.981627] [G loss: 1.588988] [D real: 0.742561] [D fake: 0.395672]
[Epoch 47/50] [Batch 899/938] [D loss: 1.032251] [G loss: 0.972809] [D real: 0.672172] [D fake: 0.393318]
[Epoch 48/50] [Batch 299/938] [D loss: 1.006597] [G loss: 1.520477] [D real: 0.724035] [D fake: 0.427500]
[Epoch 48/50] [Batch 599/938] [D loss: 0.980703] [G loss: 1.298721] [D real: 0.687361] [D fake: 0.374263]
[Epoch 48/50] [Batch 899/938] [D loss: 1.326660] [G loss: 0.897566] [D real: 0.443015] [D fake: 0.212538]
[Epoch 49/50] [Batch 299/938] [D loss: 0.924468] [G loss: 1.261113] [D real: 0.637928] [D fake: 0.281026]
[Epoch 49/50] [Batch 599/938] [D loss: 1.018169] [G loss: 1.674027] [D real: 0.746008] [D fake: 0.449878]
[Epoch 49/50] [Batch 899/938] [D loss: 1.034223] [G loss: 1.723830] [D real: 0.720788] [D fake: 0.420542]

3.保存模型

torch.save(generator.state_dict(),'./save/generator.pth')
torch.save(discriminator.state_dict(),'./save/discriminator.pth')

在本次GAN的学习中,博弈方法的引入让我耳目一新,这种随机生成噪声再不断修正的方法感觉就像是我自己的学习一样。看着厚厚的培养方案,不知道怎么下手,只有自己先瞎走一通,才能找到自己的方向。这周先是理论学习,感觉知识点挺多的,加上开学事情又多了起来。很难抽出时间来学习,只有先大致掌握原理,下周把事情忙完了再来调试。

这是我自己去找到的一小部分关于GAN的介绍
对抗生成网络(GAN)的要点可以从以下几个方面进行详细总结:

  1. 基本原理:GAN由两个关键组件构成,即生成器(Generator)和判别器(Discriminator)。生成器的目标是生成尽可能真实的数据以欺骗判别器,而判别器的目标则是准确地区分出真实数据和生成数据。
  2. 工作流程:生成器接收一个随机噪声作为输入,并通过这个噪声生成数据。判别器则接收真实数据和生成器生成的数据,其任务是判断输入数据是真实的还是由生成器产生的。
  3. 训练过程:在训练过程中,生成器和判别器进行动态的“博弈过程”。生成器不断尝试提高生成数据的质量,以更好地欺骗判别器;判别器则不断提高识别能力,以更准确地分辨真实数据和生成数据。
  4. 模型优化:为了提高GAN的性能,研究人员会尝试各种数据增强技巧,如图像的裁切、变形、调整明暗和颜色等,以提高模型在图像分类和检测任务中的精度。
  5. 核心概念:GAN的核心在于通过对抗学习的方式,使得生成器能够学习到真实数据的分布,从而产生高质量的合成数据。
  6. 应用范围:GAN的应用非常广泛,包括但不限于图像合成、风格迁移、图像到图像的转换、药物发现、语音合成等领域。
  7. 挑战与改进:尽管GAN在多个领域表现出色,但在实际应用中仍面临一些挑战,如模式崩溃、训练不稳定等问题。因此,研究人员不断提出新的架构和训练策略,如条件GAN、循环GAN、Wasserstein
    GAN等,以解决这些问题。
  8. 未来发展:随着研究的深入和技术的进步,GAN有望在更多领域发挥作用,同时也可能出现更多创新的变体,以满足不断增长的应用需求。

总的来说,GAN作为一种强大的生成模型,其核心在于通过生成器和判别器的对抗学习过程,不断提升生成数据的质量。在实际应用中,GAN的优化和改进是一个持续的过程,旨在解决训练中的挑战并扩展其应用范围。

在训练模型的那段代码中我感觉是GAN最核心的地方,不断的生成和对比,还要以标签作为标准,反复去计算损失,更新参数,让判别器确信生成的图片的真实性,从而达到对抗的效果。
在长久的学习以来我一直感受到随机噪声的亮点,就像人的发展一样,我们想尽办法,去创造一个熵减的过程去获得确定性。这是我的一个很直观的感受,由于时间有限,下周我再好好说说我的学习心得。

调整生成对抗网络(GAN)的训练参数是一个复杂的过程,以下是一些有效的调整策略:

初始化:可以通过训练一个变分自编码器(VAE)并使用其解码器权重来初始化生成器(generator)。这样做可以帮助生成器在开始时就具备一定的生成能力,从而提高训练的效率和成功率。

交叉训练:在训练过程中,可以采用不对称的训练方式,即每训练一次判别器(discriminator),训练多次生成器。这样可以让生成器在早期阶段获得更多的学习机会。

修改损失函数:可以尝试使用不同的损失函数来优化生成器和判别器。例如,使用带有梯度惩罚的损失函数可以帮助稳定训练过程,防止梯度爆炸或消失。

选择数据集:选择一个合适的数据集对于训练GAN至关重要。数据集应该包含足够多样化的样本,以便训练出能够生成高质量图像的GAN模型。

调整超参数:超参数如学习率、批次大小、优化器的选择等都会影响GAN的训练。需要根据具体情况进行调整,以找到最佳的训练配置。

使用GAN变体:有许多GAN的变体,如半监督生成对抗网络(SGAN)、边界搜索生成对抗网络(BGAN)等,它们在原始GAN的基础上做了改进,可以根据实际情况选择适合的模型架构。

监控训练过程:在训练过程中,应该密切监控生成器和判别器的性能,以及生成图像的质量。这有助于及时发现问题并调整训练策略。

优先训练判别器:在某些情况下,先训练判别器直到它能够很好地区分真实和生成的样本,然后再引入生成器的训练,这种方法可以提高训练的稳定性。

使用标签翻转:在训练生成器时,可以将生成的数据当作真实数据来训练判别器,即使用标签翻转的技巧,这样可以更有效地优化生成器。

资源内容为GAN对抗神经网络的各种常用变体:具体内容包括GAN(经典)、半监督生成对抗网络(SGAN)、边界搜索生成对抗网络(BGAN)、对偶生成对抗网络(DualGAN)、辅助分类-生成对抗网络(Auxiliary
Classifier GAN)等。

  • 20
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值