- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊 | 接辅导、项目定制
一、前期准备
1.定义超参数
代码知识点
这部分代码导入了所需的库和模块,包括
argparse用于解析命令行参数,
os用于处理文件路径,
numpy用于数值计算,
torchvision.transforms用于图像变换,
torchvision.utils中的save_image用于保存生成的图像,
torch.utils.data中的DataLoader用于加载数据,
torchvision中的datasets用于加载数据集,
torch.autograd中的Variable用于自动求导,
torch.nn用于构建神经网络模型,以及torch用于张量操作。
这些参数会影响生成对抗网络(GAN)的训练过程,具体如下:
n_epochs:这个参数决定了模型训练的总轮数。轮数越多,模型有更多机会学习数据中的模式,但也可能导致过拟合。
batch_size:批次大小影响模型每次更新时使用的数据量。较小的批次可能导致训练过程波动较大,但可能有助于模型逃离局部最小值;较大的批次则可能使训练更稳定,但需要更多的内存空间。
lr:学习率控制着模型权重更新的步长。学习率过大可能导致模型在最优解附近震荡甚至发散;学习率过小则可能导致模型收敛速度缓慢或陷入局部最小值。
b1和b2:这两个参数是Adam优化器的一部分,分别控制一阶矩(梯度的指数移动平均)和二阶矩(梯度平方的指数移动平均)的指数衰减率。它们影响模型更新的稳定性和收敛速度。
n_cpu:这个参数指定了用于数据加载的CPU数量,可以影响数据预处理和加载的速度,进而影响训练的效率。
latent_dim:随机向量的维度,它影响生成器生成图像的多样性和质量。维度过低可能导致生成图像缺乏多样性,而维度过高可能导致模型难以训练。
img_size:图像的大小直接影响模型的感受野和所需计算资源。图像尺寸越大,模型可能需要更多的计算资源和更长的训练时间。
channels:图像的通道数,对于彩色图像通常是3(RGB),对于灰度图像是1。通道数影响模型处理的信息量。
sample_interval:保存生成图像的间隔,这个参数决定了我们在训练过程中多久保存一次生成的图像,用于监控生成图像的质量。
cuda:是否使用GPU进行计算,使用GPU可以显著加速模型的训练过程,因为GPU在并行处理大量计算时更为高效。
import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch
os.makedirs("./images/",exist_ok=True)
os.makedirs("./save/",exist_ok=True)
os.makedirs("./datasets/mnist",exist_ok=True)
n_epochs=50
batch_size=64
lr=0.0002
b1=0.5
b2=0.999
n_cpu=2
latent_dim=100
img_size=28
channels=1
sample_interval=500
img_shape =(channels,img_size,img_size)
img_area = np.prod(img_shape)
cuda = True if torch.cuda.is_available() else False
print(cuda)
输出
True
2.下载数据
mnist = datasets.MNIST(
root = './datasets/', train=True, download=True, transform=transforms.Compose(
[transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
)
3.配置数据
dataloader =DataLoader(
mnist,
batch_size=batch_size,
shuffle=True,
)
二、定义模型
1.定义辨别器
代码知识点
这段代码定义了一个名为Discriminator
的类,它继承自nn.Module
。这个类是一个判别器模型,用于判断输入图像是否为真实图像。下面是对代码中每一行的详细解释:
class Discriminator(nn.Module):
:定义一个名为Discriminator
的类,它继承自nn.Module
。nn.Module
是PyTorch中的一个基类,用于构建神经网络模型。
def __init__(self):
:定义类的构造函数,用于初始化模型的参数和层。
super(Discriminator,self).__init__()
:调用父类nn.Module
的构造函数,以确保正确地初始化模型。
self.model = nn.Sequential(
:创建一个nn.Sequential
对象,它是一个容器,用于按顺序堆叠多个神经网络层。
nn.Linear(img_area,512),
:添加一个线性层,输入大小为img_area
(图像区域的像素数),输出大小为512。这个层用于将输入图像展平并映射到一个新的特征空间。
nn.LeakyReLU(0.2,inplace=True),
:添加一个Leaky ReLU激活函数,其负斜率为0.2。inplace=True
表示在原始数据上进行操作,以节省内存。
nn.Linear(512,256),
:添加一个线性层,输入大小为512,输出大小为256。这个层用于进一步将特征映射到更小的特征空间。
nn.LeakyReLU(0.2,inplace=True),
:再次添加一个Leaky ReLU激活函数,与之前的层相同。
nn.Linear(256,1),
:添加一个线性层,输入大小为256,输出大小为1。这个层用于将特征映射到一个标量值,用于表示输入图像的真实性。
nn.Sigmoid(),
:添加一个Sigmoid激活函数,将输出值限制在0到1之间。这可以解释为输入图像为真实图像的概率。
)
:结束nn.Sequential
对象的创建。
def forward(self, img):
:定义模型的前向传播函数,用于计算输入图像的输出。
img_flat = img.view(img.size(0),-1)
:将输入图像img
展平为一个一维向量。img.size(0)
表示批量大小,-1
表示自动计算剩余维度的大小。
validity = self.model(img_flat)
:将展平后的图像传递给之前定义的nn.Sequential
模型,得到一个表示图像真实性的标量值。
return validity
:返回计算得到的图像真实性值。
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator,self).__init__()
self.model = nn.Sequential(
nn.Linear(img_area,512),
nn.LeakyReLU(0.2,inplace=True),
nn.Linear(512,256),
nn.LeakyReLU(0.2,inplace=True),
nn.Linear(256,1),
nn.Sigmoid(),
)
def forward(self, img):
img_flat = img.view(img.size(0),-1)
validity = self.model(img_flat)
return validity
2.定义生成器
代码知识点
class Generator(nn.Module):
:定义一个名为Generator的类,继承自PyTorch的nn.Module模块。def __init__ (self):
:定义类的初始化方法。super(Generator,self).__init__()
:调用父类的初始化方法。def block(in_feat, out_feat, normalize=True):
:定义一个名为block的内部函数,用于构建生成器网络的每个块。layers = [nn.Linear(in_feat, out_feat)]
:创建一个线性层,输入特征数为in_feat,输出特征数为out_feat。if normalize:
:判断是否需要进行批量归一化。layers.append(nn.BatchNorm1d(out_feat, 0.8))
:添加批量归一化层,输出特征数为out_feat,动量参数为0.8。layers.append(nn.LeakyReLU(0.2,inplace=True))
:添加Leaky ReLU激活函数,负斜率为0.2,inplace参数设置为True表示直接修改输入数据。return layers
:返回构建好的层列表。self.model = nn.Sequential(
:创建一个顺序模型,用于堆叠生成器的各个层。*block(latent_dim,128,normalize=False),
:调用block函数构建第一个块,输入特征数为latent_dim,输出特征数为128,不进行批量归一化。*block(128,256),
:调用block函数构建第二个块,输入特征数为128,输出特征数为256。*block(256,512),
:调用block函数构建第三个块,输入特征数为256,输出特征数为512。*block(512,1024),
:调用block函数构建第四个块,输入特征数为512,输出特征数为1024。nn.Linear(1024,img_area),
:添加一个线性层,输入特征数为1024,输出特征数为img_area。nn.Tanh()
:添加一个Tanh激活函数,将输出值映射到-1到1之间。)
:结束顺序模型的定义。def forward(self,z):
:定义前向传播方法,输入为z。imgs = self.model(z)
:将输入z传入模型,得到输出imgs。imgs = imgs.view(imgs.size(0),*img_shape)
:将输出imgs的形状调整为(batch_size,
*img_shape)。return imgs
:返回调整形状后的imgs。
class Generator(nn.Module):
def __init__ (self):
super(Generator,self).__init__()
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)]
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
layers.append(nn.LeakyReLU(0.2,inplace=True))
return layers
## prod():返回给定轴上的数组元素的乘积:1*28*28=784
self.model = nn.Sequential(
*block(latent_dim,128,normalize=False),
*block(128,256),
*block(256,512),
*block(512,1024),
nn.Linear(1024,img_area),
nn.Tanh()
)
## view():相当于numpy中的reshape,重新定义矩阵的形状
def forward(self,z):
imgs = self.model(z)
imgs = imgs.view(imgs.size(0),*img_shape)
return imgs
三、训练模型
1.创建实例
代码知识点
generator = Generator()
:创建一个名为generator的生成器对象,该对象是Generator类的实
例。discriminator = Discriminator()
:创建一个名为discriminator的判别器对象,该对象是Discriminator类的实例。criterion = torch.nn.BCELoss()
:创建一个名为criterion的损失函数对象,该对象是二元交叉熵损失(Binary Cross
Entropy Loss)的实例。optimizer_G = torch.optim.Adam(generator.parameters(),lr=lr, betas=(b1,b2))
:创建一个名为optimizer_G的优化器对象,用于优化生成器网络的参数。该对象是Adam优化器的实例,学习率为lr,beta1和beta2分别为b1和b2。optimizer_D = torch.optim.Adam(discriminator.parameters(),lr=lr, betas=(b1,b2))
:创建一个名为optimizer_D的优化器对象,用于优化判别器网络的参数。该对象是Adam优化器的实例,学习率为lr,beta1和beta2分别为b1和b2。if torch.cuda.is_available():
:判断当前设备是否支持CUDA加速。generator = generator.cuda()
:如果支持CUDA加速,将生成器对象转移到GPU上进行计算。discriminator = discriminator.cuda()
:如果支持CUDA加速,将判别器对象转移到GPU上进行计算。criterion = criterion.cuda()
:如果支持CUDA加速,将损失函数对象转移到GPU上进行计算。
generator = Generator()
discriminator = Discriminator()
criterion = torch.nn.BCELoss()
optimizer_G = torch.optim.Adam(generator.parameters(),lr=lr, betas=(b1,b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(),lr=lr, betas=(b1,b2))
if torch.cuda.is_available():
generator = generator.cuda()
discriminator = discriminator.cuda()
criterion = criterion.cuda()
2.训练模型
代码知识点
for epoch in range(n_epochs):
:开始一个循环,循环次数为n_epochs,每次循环代表一个训练周期。for i,(imgs, _) in enumerate(dataloader):
:在每个训练周期内,遍历数据加载器(dataloader),获取图像数据和标签。imgs = imgs.view(imgs.size(0), -1)
:将图像数据调整为二维张量,其中第一维是批次大小,第二维是图像的展平表示。real_img = Variable(imgs).cuda()
:将真实的图像数据转换为变量并移动到GPU上进行计算。real_label = Variable(torch.ones(imgs.size(0),1)).cuda()
:创建一个全为1的标签向量,用于判别器的真实图像损失计算。fake_label = Variable(torch.zeros(imgs.size(0),1)).cuda()
:创建一个全为0的标签向量,用于判别器的生成图像损失计算。real_out = discriminator(real_img)
:将真实图像输入判别器,得到判别器对真实图像的输出。loss_real_D = criterion(real_out,real_label)
:计算判别器对真实图像的损失。real_scores = real_out
:将判别器对真实图像的输出保存为真实分数。z = Variable(torch.randn(imgs.size(0),latent_dim)).cuda()
:生成随机噪声向量z。fake_img = generator(z).detach()
:将噪声向量z输入生成器,得到生成的图像,并将其从计算图中分离。fake_out = discriminator(fake_img)
:将生成的图像输入判别器,得到判别器对生成图像的输出。loss_fake_D = criterion(fake_out,fake_label)
:计算判别器对生成图像的损失。fake_scores = fake_out
:将判别器对生成图像的输出保存为假分数。loss_D = loss_real_D + loss_fake_D
:计算判别器的总损失,即真实图像损失和生成图像损失之和。optimizer_D.zero_grad()
:将判别器的梯度清零。loss_D.backward()
:计算判别器的损失关于参数的梯度。optimizer_D.step()
:更新判别器的参数。z = Variable(torch.randn(imgs.size(0),latent_dim)).cuda()
:再次生成随机噪声向量z。fake_img = generator(z)
:将噪声向量z输入生成器,得到生成的图像。output = discriminator(fake_img)
:将生成的图像输入判别器,得到判别器对生成图像的输出。loss_G = criterion(output,real_label)
:计算生成器的损失。optimizer_G.zero_grad()
:将生成器的梯度清零。loss_G.backward()
:计算生成器的损失关于参数的梯度。optimizer_G.step()
:更新生成器的参数。if (i+1) % 300 == 0:
:每隔300个批次打印一次训练信息。print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]" % (epoch,n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean()))
:打印当前训练周期、批次、判别器损失、生成器损失、判别器对真实图像的平均分数和判别器对生成图像的平均分数。batches_done = epoch * len(dataloader) + i
:计算已完成的批次数。if batches_done % sample_interval == 0:
:每隔一定间隔保存生成的图像。save_image(fake_img.data[:25],"./images/%d.png" % batches_done, nrow=5, normalize=True)
:保存前25个生成的图像到指定路径,并进行归一化处理。
for epoch in range(n_epochs):
for i,(imgs, _) in enumerate(dataloader):
imgs = imgs.view(imgs.size(0), -1)
real_img = Variable(imgs).cuda()
real_label = Variable(torch.ones(imgs.size(0),1)).cuda()
fake_label = Variable(torch.zeros(imgs.size(0),1)).cuda()
real_out = discriminator(real_img)
loss_real_D = criterion(real_out,real_label)
real_scores = real_out
z = Variable(torch.randn(imgs.size(0),latent_dim)).cuda()
fake_img = generator(z).detach()
fake_out = discriminator(fake_img)
loss_fake_D = criterion(fake_out,fake_label)
fake_scores = fake_out
loss_D = loss_real_D + loss_fake_D
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
z = Variable(torch.randn(imgs.size(0),latent_dim)).cuda()
fake_img = generator(z)
output = discriminator(fake_img)
loss_G = criterion(output,real_label)
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()
if (i+1) % 300 == 0:
print(
"[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
% (epoch,n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean())
)
batches_done = epoch * len(dataloader) + i
if batches_done % sample_interval == 0:
save_image(fake_img.data[:25],"./images/%d.png" % batches_done, nrow=5, normalize=True)
输出
[Epoch 0/50] [Batch 299/938] [D loss: 1.486548] [G loss: 1.263540] [D real: 0.796954] [D fake: 0.705474]
[Epoch 0/50] [Batch 599/938] [D loss: 1.022943] [G loss: 0.799070] [D real: 0.472179] [D fake: 0.185987]
[Epoch 0/50] [Batch 899/938] [D loss: 1.116177] [G loss: 1.078084] [D real: 0.567511] [D fake: 0.387150]
[Epoch 1/50] [Batch 299/938] [D loss: 1.492023] [G loss: 2.803232] [D real: 0.912035] [D fake: 0.741440]
[Epoch 1/50] [Batch 599/938] [D loss: 1.088952] [G loss: 1.283060] [D real: 0.638842] [D fake: 0.424129]
[Epoch 1/50] [Batch 899/938] [D loss: 1.045345] [G loss: 1.928071] [D real: 0.777452] [D fake: 0.519182]
[Epoch 2/50] [Batch 299/938] [D loss: 0.924773] [G loss: 1.558642] [D real: 0.691493] [D fake: 0.375394]
[Epoch 2/50] [Batch 599/938] [D loss: 0.859401] [G loss: 1.168849] [D real: 0.610856] [D fake: 0.238899]
[Epoch 2/50] [Batch 899/938] [D loss: 0.994440] [G loss: 0.886017] [D real: 0.506867] [D fake: 0.113889]
[Epoch 3/50] [Batch 299/938] [D loss: 1.111854] [G loss: 0.959417] [D real: 0.470583] [D fake: 0.124310]
[Epoch 3/50] [Batch 599/938] [D loss: 0.795403] [G loss: 1.288844] [D real: 0.660531] [D fake: 0.231165]
[Epoch 3/50] [Batch 899/938] [D loss: 0.916148] [G loss: 2.614962] [D real: 0.825633] [D fake: 0.481356]
[Epoch 4/50] [Batch 299/938] [D loss: 0.884143] [G loss: 2.002452] [D real: 0.764259] [D fake: 0.402012]
[Epoch 4/50] [Batch 599/938] [D loss: 0.741833] [G loss: 2.512850] [D real: 0.797984] [D fake: 0.330773]
[Epoch 4/50] [Batch 899/938] [D loss: 0.778115] [G loss: 1.621227] [D real: 0.742208] [D fake: 0.324576]
[Epoch 5/50] [Batch 299/938] [D loss: 0.764483] [G loss: 1.628026] [D real: 0.764330] [D fake: 0.341625]
[Epoch 5/50] [Batch 599/938] [D loss: 0.863362] [G loss: 2.213945] [D real: 0.813163] [D fake: 0.430946]
[Epoch 5/50] [Batch 899/938] [D loss: 0.816530] [G loss: 1.191403] [D real: 0.599072] [D fake: 0.124816]
[Epoch 6/50] [Batch 299/938] [D loss: 0.893659] [G loss: 2.667447] [D real: 0.821634] [D fake: 0.457403]
[Epoch 6/50] [Batch 599/938] [D loss: 0.937223] [G loss: 0.625343] [D real: 0.553249] [D fake: 0.183966]
[Epoch 6/50] [Batch 899/938] [D loss: 0.858882] [G loss: 1.268875] [D real: 0.662991] [D fake: 0.270400]
[Epoch 7/50] [Batch 299/938] [D loss: 0.920097] [G loss: 1.110757] [D real: 0.662119] [D fake: 0.310955]
[Epoch 7/50] [Batch 599/938] [D loss: 0.848626] [G loss: 1.869307] [D real: 0.794307] [D fake: 0.407518]
[Epoch 7/50] [Batch 899/938] [D loss: 0.894683] [G loss: 1.515777] [D real: 0.668894] [D fake: 0.315810]
[Epoch 8/50] [Batch 299/938] [D loss: 0.945508] [G loss: 1.450796] [D real: 0.684737] [D fake: 0.359553]
[Epoch 8/50] [Batch 599/938] [D loss: 0.893079] [G loss: 1.032745] [D real: 0.595712] [D fake: 0.202011]
[Epoch 8/50] [Batch 899/938] [D loss: 1.158379] [G loss: 0.836311] [D real: 0.476757] [D fake: 0.137825]
[Epoch 9/50] [Batch 299/938] [D loss: 1.194173] [G loss: 1.441912] [D real: 0.715668] [D fake: 0.492410]
[Epoch 9/50] [Batch 599/938] [D loss: 0.977941] [G loss: 1.354820] [D real: 0.690165] [D fake: 0.390519]
[Epoch 9/50] [Batch 899/938] [D loss: 1.047443] [G loss: 0.890079] [D real: 0.501754] [D fake: 0.143151]
[Epoch 10/50] [Batch 299/938] [D loss: 0.928828] [G loss: 1.659332] [D real: 0.735775] [D fake: 0.406559]
[Epoch 10/50] [Batch 599/938] [D loss: 0.805291] [G loss: 1.204942] [D real: 0.677094] [D fake: 0.243433]
[Epoch 10/50] [Batch 899/938] [D loss: 1.082386] [G loss: 0.736025] [D real: 0.518290] [D fake: 0.203062]
[Epoch 11/50] [Batch 299/938] [D loss: 1.207286] [G loss: 2.173785] [D real: 0.823033] [D fake: 0.567325]
[Epoch 11/50] [Batch 599/938] [D loss: 0.864652] [G loss: 1.226438] [D real: 0.670658] [D fake: 0.292929]
[Epoch 11/50] [Batch 899/938] [D loss: 1.067761] [G loss: 0.816991] [D real: 0.557894] [D fake: 0.265028]
[Epoch 12/50] [Batch 299/938] [D loss: 1.094701] [G loss: 0.828730] [D real: 0.478709] [D fake: 0.152698]
[Epoch 12/50] [Batch 599/938] [D loss: 0.923310] [G loss: 1.575744] [D real: 0.673793] [D fake: 0.341787]
[Epoch 12/50] [Batch 899/938] [D loss: 0.882771] [G loss: 1.570408] [D real: 0.721132] [D fake: 0.348645]
[Epoch 13/50] [Batch 299/938] [D loss: 1.213910] [G loss: 1.739210] [D real: 0.816769] [D fake: 0.581116]
[Epoch 13/50] [Batch 599/938] [D loss: 0.986962] [G loss: 2.176203] [D real: 0.801824] [D fake: 0.485661]
[Epoch 13/50] [Batch 899/938] [D loss: 1.123384] [G loss: 1.887434] [D real: 0.788300] [D fake: 0.526488]
[Epoch 14/50] [Batch 299/938] [D loss: 0.993780] [G loss: 1.074014] [D real: 0.621304] [D fake: 0.302661]
[Epoch 14/50] [Batch 599/938] [D loss: 1.139293] [G loss: 0.902772] [D real: 0.486842] [D fake: 0.136740]
[Epoch 14/50] [Batch 899/938] [D loss: 1.039376] [G loss: 0.765984] [D real: 0.579033] [D fake: 0.273950]
[Epoch 15/50] [Batch 299/938] [D loss: 0.957737] [G loss: 1.730967] [D real: 0.735324] [D fake: 0.417273]
[Epoch 15/50] [Batch 599/938] [D loss: 1.133914] [G loss: 1.851659] [D real: 0.768372] [D fake: 0.543797]
[Epoch 15/50] [Batch 899/938] [D loss: 0.953224] [G loss: 1.377125] [D real: 0.664915] [D fake: 0.350922]
[Epoch 16/50] [Batch 299/938] [D loss: 0.970558] [G loss: 1.432826] [D real: 0.672650] [D fake: 0.376319]
[Epoch 16/50] [Batch 599/938] [D loss: 0.917780] [G loss: 1.423335] [D real: 0.716623] [D fake: 0.372080]
[Epoch 16/50] [Batch 899/938] [D loss: 1.034982] [G loss: 1.312837] [D real: 0.676499] [D fake: 0.388286]
[Epoch 17/50] [Batch 299/938] [D loss: 1.283185] [G loss: 2.052758] [D real: 0.843837] [D fake: 0.627986]
[Epoch 17/50] [Batch 599/938] [D loss: 1.014387] [G loss: 1.197765] [D real: 0.630091] [D fake: 0.338972]
[Epoch 17/50] [Batch 899/938] [D loss: 0.953568] [G loss: 1.231606] [D real: 0.657918] [D fake: 0.351649]
[Epoch 18/50] [Batch 299/938] [D loss: 1.032209] [G loss: 1.185335] [D real: 0.588554] [D fake: 0.312889]
[Epoch 18/50] [Batch 599/938] [D loss: 0.938251] [G loss: 1.347066] [D real: 0.614758] [D fake: 0.287012]
[Epoch 18/50] [Batch 899/938] [D loss: 1.098345] [G loss: 1.045380] [D real: 0.670897] [D fake: 0.438555]
[Epoch 19/50] [Batch 299/938] [D loss: 1.041932] [G loss: 1.125989] [D real: 0.623168] [D fake: 0.349922]
[Epoch 19/50] [Batch 599/938] [D loss: 1.049016] [G loss: 0.997303] [D real: 0.560449] [D fake: 0.289190]
[Epoch 19/50] [Batch 899/938] [D loss: 1.109702] [G loss: 0.830619] [D real: 0.554481] [D fake: 0.302520]
[Epoch 20/50] [Batch 299/938] [D loss: 1.091989] [G loss: 1.078110] [D real: 0.579322] [D fake: 0.306465]
[Epoch 20/50] [Batch 599/938] [D loss: 1.075116] [G loss: 1.026978] [D real: 0.646824] [D fake: 0.379286]
[Epoch 20/50] [Batch 899/938] [D loss: 1.012776] [G loss: 1.252582] [D real: 0.650999] [D fake: 0.355130]
[Epoch 21/50] [Batch 299/938] [D loss: 1.019206] [G loss: 1.080652] [D real: 0.571173] [D fake: 0.282273]
[Epoch 21/50] [Batch 599/938] [D loss: 1.092633] [G loss: 0.927598] [D real: 0.527832] [D fake: 0.257436]
[Epoch 21/50] [Batch 899/938] [D loss: 1.221342] [G loss: 1.074707] [D real: 0.572400] [D fake: 0.380784]
[Epoch 22/50] [Batch 299/938] [D loss: 1.065374] [G loss: 0.797963] [D real: 0.484218] [D fake: 0.195965]
[Epoch 22/50] [Batch 599/938] [D loss: 1.105837] [G loss: 0.691518] [D real: 0.514260] [D fake: 0.276405]
[Epoch 22/50] [Batch 899/938] [D loss: 0.878742] [G loss: 1.166016] [D real: 0.724698] [D fake: 0.377328]
[Epoch 23/50] [Batch 299/938] [D loss: 1.011820] [G loss: 1.244210] [D real: 0.713064] [D fake: 0.429894]
[Epoch 23/50] [Batch 599/938] [D loss: 1.073738] [G loss: 1.664922] [D real: 0.719866] [D fake: 0.461699]
[Epoch 23/50] [Batch 899/938] [D loss: 1.009477] [G loss: 1.276478] [D real: 0.616435] [D fake: 0.336026]
[Epoch 24/50] [Batch 299/938] [D loss: 1.020366] [G loss: 1.251225] [D real: 0.648893] [D fake: 0.388956]
[Epoch 24/50] [Batch 599/938] [D loss: 1.101402] [G loss: 2.135847] [D real: 0.786404] [D fake: 0.536431]
[Epoch 24/50] [Batch 899/938] [D loss: 1.202906] [G loss: 0.715893] [D real: 0.452851] [D fake: 0.163171]
[Epoch 25/50] [Batch 299/938] [D loss: 1.140191] [G loss: 0.871201] [D real: 0.527046] [D fake: 0.297866]
[Epoch 25/50] [Batch 599/938] [D loss: 1.117683] [G loss: 1.204782] [D real: 0.623783] [D fake: 0.375783]
[Epoch 25/50] [Batch 899/938] [D loss: 1.122934] [G loss: 1.279258] [D real: 0.694599] [D fake: 0.453740]
[Epoch 26/50] [Batch 299/938] [D loss: 1.059794] [G loss: 0.850840] [D real: 0.512065] [D fake: 0.202767]
[Epoch 26/50] [Batch 599/938] [D loss: 1.072658] [G loss: 1.585017] [D real: 0.740615] [D fake: 0.476374]
[Epoch 26/50] [Batch 899/938] [D loss: 1.177186] [G loss: 0.832109] [D real: 0.472754] [D fake: 0.196632]
[Epoch 27/50] [Batch 299/938] [D loss: 0.986746] [G loss: 1.037914] [D real: 0.552875] [D fake: 0.219123]
[Epoch 27/50] [Batch 599/938] [D loss: 1.001620] [G loss: 1.654807] [D real: 0.745383] [D fake: 0.430741]
[Epoch 27/50] [Batch 899/938] [D loss: 1.007908] [G loss: 0.997429] [D real: 0.587789] [D fake: 0.299459]
[Epoch 28/50] [Batch 299/938] [D loss: 1.002784] [G loss: 1.241717] [D real: 0.709475] [D fake: 0.420782]
[Epoch 28/50] [Batch 599/938] [D loss: 1.058915] [G loss: 1.121608] [D real: 0.660302] [D fake: 0.398084]
[Epoch 28/50] [Batch 899/938] [D loss: 0.949169] [G loss: 0.922655] [D real: 0.599296] [D fake: 0.259329]
[Epoch 29/50] [Batch 299/938] [D loss: 1.099054] [G loss: 0.816260] [D real: 0.549399] [D fake: 0.305472]
[Epoch 29/50] [Batch 599/938] [D loss: 1.008659] [G loss: 1.256130] [D real: 0.737606] [D fake: 0.429921]
[Epoch 29/50] [Batch 899/938] [D loss: 1.039975] [G loss: 1.456424] [D real: 0.730297] [D fake: 0.425789]
[Epoch 30/50] [Batch 299/938] [D loss: 1.063697] [G loss: 1.032279] [D real: 0.618258] [D fake: 0.364028]
[Epoch 30/50] [Batch 599/938] [D loss: 1.114642] [G loss: 1.696752] [D real: 0.735125] [D fake: 0.477111]
[Epoch 30/50] [Batch 899/938] [D loss: 1.077305] [G loss: 1.105972] [D real: 0.530655] [D fake: 0.231420]
[Epoch 31/50] [Batch 299/938] [D loss: 1.011994] [G loss: 1.358363] [D real: 0.667097] [D fake: 0.406712]
[Epoch 31/50] [Batch 599/938] [D loss: 0.991479] [G loss: 1.544970] [D real: 0.680855] [D fake: 0.399283]
[Epoch 31/50] [Batch 899/938] [D loss: 1.009350] [G loss: 1.655224] [D real: 0.741437] [D fake: 0.454083]
[Epoch 32/50] [Batch 299/938] [D loss: 1.009774] [G loss: 1.154610] [D real: 0.697059] [D fake: 0.417741]
[Epoch 32/50] [Batch 599/938] [D loss: 1.131516] [G loss: 1.604385] [D real: 0.799639] [D fake: 0.521479]
[Epoch 32/50] [Batch 899/938] [D loss: 0.971643] [G loss: 1.751588] [D real: 0.769394] [D fake: 0.447989]
[Epoch 33/50] [Batch 299/938] [D loss: 0.937302] [G loss: 1.137668] [D real: 0.667549] [D fake: 0.352494]
[Epoch 33/50] [Batch 599/938] [D loss: 1.142665] [G loss: 1.605705] [D real: 0.736860] [D fake: 0.494053]
[Epoch 33/50] [Batch 899/938] [D loss: 0.975630] [G loss: 1.081744] [D real: 0.617061] [D fake: 0.311667]
[Epoch 34/50] [Batch 299/938] [D loss: 1.004504] [G loss: 1.199623] [D real: 0.584345] [D fake: 0.250685]
[Epoch 34/50] [Batch 599/938] [D loss: 1.111552] [G loss: 1.399492] [D real: 0.683633] [D fake: 0.437264]
[Epoch 34/50] [Batch 899/938] [D loss: 1.057444] [G loss: 1.816881] [D real: 0.718745] [D fake: 0.452988]
[Epoch 35/50] [Batch 299/938] [D loss: 1.072661] [G loss: 1.064296] [D real: 0.542904] [D fake: 0.252511]
[Epoch 35/50] [Batch 599/938] [D loss: 1.006217] [G loss: 1.576730] [D real: 0.748175] [D fake: 0.426752]
[Epoch 35/50] [Batch 899/938] [D loss: 1.006518] [G loss: 1.166273] [D real: 0.663842] [D fake: 0.369748]
[Epoch 36/50] [Batch 299/938] [D loss: 1.074173] [G loss: 0.836063] [D real: 0.564981] [D fake: 0.312169]
[Epoch 36/50] [Batch 599/938] [D loss: 1.115033] [G loss: 1.369953] [D real: 0.592239] [D fake: 0.329859]
[Epoch 36/50] [Batch 899/938] [D loss: 1.022781] [G loss: 1.326988] [D real: 0.578367] [D fake: 0.262427]
[Epoch 37/50] [Batch 299/938] [D loss: 0.978686] [G loss: 1.245194] [D real: 0.673857] [D fake: 0.346784]
[Epoch 37/50] [Batch 599/938] [D loss: 1.157297] [G loss: 1.436168] [D real: 0.660286] [D fake: 0.434538]
[Epoch 37/50] [Batch 899/938] [D loss: 0.866226] [G loss: 1.513450] [D real: 0.741742] [D fake: 0.375720]
[Epoch 38/50] [Batch 299/938] [D loss: 1.042357] [G loss: 1.409830] [D real: 0.776054] [D fake: 0.474026]
[Epoch 38/50] [Batch 599/938] [D loss: 0.934915] [G loss: 1.663973] [D real: 0.702456] [D fake: 0.368171]
[Epoch 38/50] [Batch 899/938] [D loss: 1.059174] [G loss: 1.293531] [D real: 0.617721] [D fake: 0.317426]
[Epoch 39/50] [Batch 299/938] [D loss: 0.964760] [G loss: 1.176818] [D real: 0.639019] [D fake: 0.328120]
[Epoch 39/50] [Batch 599/938] [D loss: 1.327827] [G loss: 2.243500] [D real: 0.822750] [D fake: 0.608486]
[Epoch 39/50] [Batch 899/938] [D loss: 0.964594] [G loss: 1.032648] [D real: 0.629228] [D fake: 0.300384]
[Epoch 40/50] [Batch 299/938] [D loss: 1.098591] [G loss: 0.879305] [D real: 0.577749] [D fake: 0.312382]
[Epoch 40/50] [Batch 599/938] [D loss: 0.914958] [G loss: 1.446056] [D real: 0.676897] [D fake: 0.326900]
[Epoch 40/50] [Batch 899/938] [D loss: 0.959015] [G loss: 1.476370] [D real: 0.706569] [D fake: 0.380073]
[Epoch 41/50] [Batch 299/938] [D loss: 1.369267] [G loss: 2.306901] [D real: 0.841119] [D fake: 0.601469]
[Epoch 41/50] [Batch 599/938] [D loss: 0.934483] [G loss: 1.628165] [D real: 0.624097] [D fake: 0.244898]
[Epoch 41/50] [Batch 899/938] [D loss: 1.014262] [G loss: 1.479493] [D real: 0.729656] [D fake: 0.417756]
[Epoch 42/50] [Batch 299/938] [D loss: 1.011384] [G loss: 1.144827] [D real: 0.611724] [D fake: 0.298554]
[Epoch 42/50] [Batch 599/938] [D loss: 0.968165] [G loss: 1.469469] [D real: 0.697256] [D fake: 0.376759]
[Epoch 42/50] [Batch 899/938] [D loss: 0.896047] [G loss: 1.262222] [D real: 0.693943] [D fake: 0.340994]
[Epoch 43/50] [Batch 299/938] [D loss: 1.098878] [G loss: 1.494534] [D real: 0.622190] [D fake: 0.352108]
[Epoch 43/50] [Batch 599/938] [D loss: 1.005295] [G loss: 1.588117] [D real: 0.714370] [D fake: 0.395605]
[Epoch 43/50] [Batch 899/938] [D loss: 0.841020] [G loss: 1.947297] [D real: 0.727799] [D fake: 0.331271]
[Epoch 44/50] [Batch 299/938] [D loss: 0.967317] [G loss: 1.357352] [D real: 0.663303] [D fake: 0.335372]
[Epoch 44/50] [Batch 599/938] [D loss: 0.907341] [G loss: 1.253705] [D real: 0.699086] [D fake: 0.349763]
[Epoch 44/50] [Batch 899/938] [D loss: 0.898501] [G loss: 1.435176] [D real: 0.682133] [D fake: 0.343082]
[Epoch 45/50] [Batch 299/938] [D loss: 0.954293] [G loss: 1.709867] [D real: 0.805845] [D fake: 0.466294]
[Epoch 45/50] [Batch 599/938] [D loss: 0.932705] [G loss: 1.283069] [D real: 0.612603] [D fake: 0.245902]
[Epoch 45/50] [Batch 899/938] [D loss: 1.107646] [G loss: 1.293812] [D real: 0.543919] [D fake: 0.228042]
[Epoch 46/50] [Batch 299/938] [D loss: 0.867950] [G loss: 1.428596] [D real: 0.673719] [D fake: 0.282370]
[Epoch 46/50] [Batch 599/938] [D loss: 1.078684] [G loss: 1.340557] [D real: 0.553742] [D fake: 0.248088]
[Epoch 46/50] [Batch 899/938] [D loss: 1.066378] [G loss: 1.557076] [D real: 0.690389] [D fake: 0.396920]
[Epoch 47/50] [Batch 299/938] [D loss: 0.871799] [G loss: 1.392173] [D real: 0.674179] [D fake: 0.292265]
[Epoch 47/50] [Batch 599/938] [D loss: 0.981627] [G loss: 1.588988] [D real: 0.742561] [D fake: 0.395672]
[Epoch 47/50] [Batch 899/938] [D loss: 1.032251] [G loss: 0.972809] [D real: 0.672172] [D fake: 0.393318]
[Epoch 48/50] [Batch 299/938] [D loss: 1.006597] [G loss: 1.520477] [D real: 0.724035] [D fake: 0.427500]
[Epoch 48/50] [Batch 599/938] [D loss: 0.980703] [G loss: 1.298721] [D real: 0.687361] [D fake: 0.374263]
[Epoch 48/50] [Batch 899/938] [D loss: 1.326660] [G loss: 0.897566] [D real: 0.443015] [D fake: 0.212538]
[Epoch 49/50] [Batch 299/938] [D loss: 0.924468] [G loss: 1.261113] [D real: 0.637928] [D fake: 0.281026]
[Epoch 49/50] [Batch 599/938] [D loss: 1.018169] [G loss: 1.674027] [D real: 0.746008] [D fake: 0.449878]
[Epoch 49/50] [Batch 899/938] [D loss: 1.034223] [G loss: 1.723830] [D real: 0.720788] [D fake: 0.420542]
3.保存模型
torch.save(generator.state_dict(),'./save/generator.pth')
torch.save(discriminator.state_dict(),'./save/discriminator.pth')
在本次GAN的学习中,博弈方法的引入让我耳目一新,这种随机生成噪声再不断修正的方法感觉就像是我自己的学习一样。看着厚厚的培养方案,不知道怎么下手,只有自己先瞎走一通,才能找到自己的方向。这周先是理论学习,感觉知识点挺多的,加上开学事情又多了起来。很难抽出时间来学习,只有先大致掌握原理,下周把事情忙完了再来调试。
这是我自己去找到的一小部分关于GAN的介绍
对抗生成网络(GAN)的要点可以从以下几个方面进行详细总结:
- 基本原理:GAN由两个关键组件构成,即生成器(Generator)和判别器(Discriminator)。生成器的目标是生成尽可能真实的数据以欺骗判别器,而判别器的目标则是准确地区分出真实数据和生成数据。
- 工作流程:生成器接收一个随机噪声作为输入,并通过这个噪声生成数据。判别器则接收真实数据和生成器生成的数据,其任务是判断输入数据是真实的还是由生成器产生的。
- 训练过程:在训练过程中,生成器和判别器进行动态的“博弈过程”。生成器不断尝试提高生成数据的质量,以更好地欺骗判别器;判别器则不断提高识别能力,以更准确地分辨真实数据和生成数据。
- 模型优化:为了提高GAN的性能,研究人员会尝试各种数据增强技巧,如图像的裁切、变形、调整明暗和颜色等,以提高模型在图像分类和检测任务中的精度。
- 核心概念:GAN的核心在于通过对抗学习的方式,使得生成器能够学习到真实数据的分布,从而产生高质量的合成数据。
- 应用范围:GAN的应用非常广泛,包括但不限于图像合成、风格迁移、图像到图像的转换、药物发现、语音合成等领域。
- 挑战与改进:尽管GAN在多个领域表现出色,但在实际应用中仍面临一些挑战,如模式崩溃、训练不稳定等问题。因此,研究人员不断提出新的架构和训练策略,如条件GAN、循环GAN、Wasserstein
GAN等,以解决这些问题。- 未来发展:随着研究的深入和技术的进步,GAN有望在更多领域发挥作用,同时也可能出现更多创新的变体,以满足不断增长的应用需求。
总的来说,GAN作为一种强大的生成模型,其核心在于通过生成器和判别器的对抗学习过程,不断提升生成数据的质量。在实际应用中,GAN的优化和改进是一个持续的过程,旨在解决训练中的挑战并扩展其应用范围。
在训练模型的那段代码中我感觉是GAN最核心的地方,不断的生成和对比,还要以标签作为标准,反复去计算损失,更新参数,让判别器确信生成的图片的真实性,从而达到对抗的效果。
在长久的学习以来我一直感受到随机噪声的亮点,就像人的发展一样,我们想尽办法,去创造一个熵减的过程去获得确定性。这是我的一个很直观的感受,由于时间有限,下周我再好好说说我的学习心得。
调整生成对抗网络(GAN)的训练参数是一个复杂的过程,以下是一些有效的调整策略:
初始化:可以通过训练一个变分自编码器(VAE)并使用其解码器权重来初始化生成器(generator)。这样做可以帮助生成器在开始时就具备一定的生成能力,从而提高训练的效率和成功率。
交叉训练:在训练过程中,可以采用不对称的训练方式,即每训练一次判别器(discriminator),训练多次生成器。这样可以让生成器在早期阶段获得更多的学习机会。
修改损失函数:可以尝试使用不同的损失函数来优化生成器和判别器。例如,使用带有梯度惩罚的损失函数可以帮助稳定训练过程,防止梯度爆炸或消失。
选择数据集:选择一个合适的数据集对于训练GAN至关重要。数据集应该包含足够多样化的样本,以便训练出能够生成高质量图像的GAN模型。
调整超参数:超参数如学习率、批次大小、优化器的选择等都会影响GAN的训练。需要根据具体情况进行调整,以找到最佳的训练配置。
使用GAN变体:有许多GAN的变体,如半监督生成对抗网络(SGAN)、边界搜索生成对抗网络(BGAN)等,它们在原始GAN的基础上做了改进,可以根据实际情况选择适合的模型架构。
监控训练过程:在训练过程中,应该密切监控生成器和判别器的性能,以及生成图像的质量。这有助于及时发现问题并调整训练策略。
优先训练判别器:在某些情况下,先训练判别器直到它能够很好地区分真实和生成的样本,然后再引入生成器的训练,这种方法可以提高训练的稳定性。
使用标签翻转:在训练生成器时,可以将生成的数据当作真实数据来训练判别器,即使用标签翻转的技巧,这样可以更有效地优化生成器。
资源内容为GAN对抗神经网络的各种常用变体:具体内容包括GAN(经典)、半监督生成对抗网络(SGAN)、边界搜索生成对抗网络(BGAN)、对偶生成对抗网络(DualGAN)、辅助分类-生成对抗网络(Auxiliary
Classifier GAN)等。