from __future__ import print_function
import os
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
if __name__ == '__main__':
# 数据集的根目录
dataroot = "E:/images"
# 数据加载器的子进程数
workers = 2
# 训练批量大小
batch_size = 128
# S调整训练图片大小
image_size = 64
# 通道数为3
nc = 3
# 图片向量
nz = 100
# 生成器中特征映射的大小
ngf = 64
# 鉴别器中特征映射的大小
ndf = 64
# 训练次数
num_epochs = 50
# 优化器学习率
lr = 0.0002
#超参数
beta1 = 0.5
# gpu数量
ngpu = 1
# 创建数据集
dataset = dset.ImageFolder(root=dataroot,
transform=transforms.Compose([
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]))
# 创建dataloader(输向模型)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
shuffle=True, num_workers=workers)
# 在gpu上运行
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")
# 绘制一些训练图像
real_batch = next(iter(dataloader))
plt.figure(figsize=(10,10))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
# 在netG和netD上调用自定义权重初始化
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
# 生成器代码
class Generator(nn.Module):
def __init__(self, ngpu):
super(Generator, self).__init__()
self.ngpu = ngpu
self.main = nn.Sequential(
# input is Z, going into a convolution
nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
nn.BatchNorm2d(ngf * 8),
nn.ReLU(True),
# state size. (ngf*8) x 4 x 4
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 4),
nn.ReLU(True),
# state size. (ngf*4) x 8 x 8
nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf * 2),
nn.ReLU(True),
# state size. (ngf*2) x 16 x 16
nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
nn.BatchNorm2d(ngf),
nn.ReLU(True),
# state size. (ngf) x 32 x 32
nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
nn.Tanh()
# state size. (nc) x 64 x 64
)
def forward(self, input):
return self.main(input)
# 创建生成器
netG = Generator(ngpu).to(device)
# 运行gpu
if (device.type == 'cuda') and (ngpu > 1):
netG = nn.DataParallel(netG, list(range(ngpu)))
# 应用weights_init函数随机初始化所有权重
# to mean=0, stdev=0.02.
netG.apply(weights_init)
print(netG)
#判别器代码
class Discriminator(nn.Module):
def __init__(self, ngpu):
super(Discriminator, self).__init__()
self.ngpu = ngpu
self.main = nn.Sequential(
# input is (nc) x 64 x 64
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf) x 32 x 32
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 2),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*2) x 16 x 16
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 4),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*4) x 8 x 8
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
nn.BatchNorm2d(ndf * 8),
nn.LeakyReLU(0.2, inplace=True),
# state size. (ndf*8) x 4 x 4
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
nn.Sigmoid()
)
def forward(self, input):
return self.main(input)
# 创建判别器
netD = Discriminator(ngpu).to(device)
# 运用gpu
if (device.type == 'cuda') and (ngpu > 1):
netD = nn.DataParallel(netD, list(range(ngpu)))
# 应用weights_init函数随机初始化所有权重
# to mean=0, stdev=0.2.
netD.apply(weights_init)
print(netD)
# 初始化BCELoss函数
criterion = nn.BCELoss()
# 创建一批潜在的向量
fixed_noise = torch.randn(64, nz, 1, 1, device=device)
# 在培训中建立真假标签
real_label = 1.
fake_label = 0.
# 为G和D设置Adam优化器
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))
# 记录进度的列表
img_list = []
G_losses = []
D_losses = []
iters = 0
print("Starting Training Loop...")
# For each epoch
for epoch in range(num_epochs):
# For each batch in the dataloader
for i, data in enumerate(dataloader, 0):
#训练全真实批次
netD.zero_grad()
#批次界面文件格式
real_cpu = data[0].to(device)
b_size = real_cpu.size(0)
label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
# 通过D转发真实批次
output = netD(real_cpu).view(-1)
# 计算全实数批处理的损失
errD_real = criterion(output, label)
# 计算D在向后通过时的梯度
errD_real.backward()
D_x = output.mean().item()
#用全假批次训练
# 生成一批潜在的向量
noise = torch.randn(b_size, nz, 1, 1, device=device)
# 使用G生成伪图像批处理
fake = netG(noise)
label.fill_(fake_label)
# 用D对所有假批次进行分类
output = netD(fake.detach()).view(-1)
# 计算D在全假批次上的损失
errD_fake = criterion(output, label)
# 计算此批的梯度与以前的梯度累计(相加)
errD_fake.backward()
D_G_z1 = output.mean().item()
# 计算D的误差为假批和真批之和
errD = errD_real + errD_fake
optimizerD.step()
netG.zero_grad()
label.fill_(real_label) # 假标签是真实的判别
# 通过D执行另一个全假批处理的前向传递
output = netD(fake).view(-1)
# 根据这个输出计算G的损失
errG = criterion(output, label)
# 计算G的梯度
errG.backward()
D_G_z2 = output.mean().item()
optimizerG.step()
# 输出培训数据
if i % 50 == 0:
print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
% (epoch, num_epochs, i, len(dataloader),
errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
# 保存损失以备以后绘图
G_losses.append(errG.item())
D_losses.append(errD.item())
# 通过将G的输出保存在fixed_noise上
if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
with torch.no_grad():
fake = netG(fixed_noise).detach().cpu()
img_list.append(vutils.make_grid(fake, padding=3, normalize=True))
iters += 1
plt.figure(figsize=(10,5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()
# 从数据加载器中获取一批真实图像
real_batch = next(iter(dataloader))
# 绘制真实图像
plt.figure(figsize=(15,15))
plt.subplot(1,2,1)
plt.axis("off")
plt.title("Real Images")
plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=3, normalize=True).cpu(),(1,2,0)))
# 画出上个时代的假图像
plt.subplot(1,2,2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1],(1,2,0)))
plt.show()
使用Pytorch通过GANs生成对抗网络来生成图像
于 2023-02-16 19:55:58 首次发布
该代码段展示了如何在PyTorch中设置和训练一个生成对抗网络(GAN),用于生成图像。它包括了数据加载、网络结构定义(生成器Generator和判别器Discriminator)、损失函数、优化器以及训练循环。训练过程中,同时更新生成器和判别器的权重以达到平衡。
摘要由CSDN通过智能技术生成