使用PyTorch实现GAN(生成对抗网络)的详细步骤

使用PyTorch实现GAN(生成对抗网络)

在这篇博客中,我们将探讨如何使用PyTorch实现一个简单的生成对抗网络(GAN),用于生成手写数字图像(MNIST数据集)。我们将详细介绍代码的每一个部分,并确保所有注释信息都被保留。

1. 导入必要的库

以下是我们用到的主要库:

import torchvision
import torch
from matplotlib import pyplot as plt
from torch import nn, Tensor
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import os

这些库将帮助我们处理数据、构建模型、训练网络以及可视化生成的图像。

2. 数据准备

我们需要首先对训练数据进行处理。下面是MNIST数据集的准备过程:

# 数据准备
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
train_dataset = torchvision.datasets.MNIST("./dataset", train=True, transform=transform, download=True)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, drop_last=True)

我们使用了transforms.Compose将数据转换为张量并进行归一化,以便模型能够更好地进行学习。

3. 生成器和判别器的定义

接下来,我们定义生成器(Generator)和判别器(Discriminator)。

生成器:

# 生成器
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(100, 128),
            nn.Linear(128, 256),
            nn.Linear(256, 512),
            nn.Linear(512, 28 * 28),
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)

生成器将随机噪声输入转化为28x28的图像。

判别器:

# 判别器
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 512),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = x.view(-1, 28, 28)
        return self.model(x)

判别器负责区分生成的图像和真实图像。

4. 初始化模型,优化器及损失计算函数

接下来,我们初始化模型、优化器以及损失函数。

# 初始化模型,优化器及损失计算函数
writer = SummaryWriter("./p1log")
device = "cuda" if torch.cuda.is_available() else "cpu"
gen = Generator().to(device)
dis = Discriminator().to(device)
# 优化器
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-3)
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-3)
# 损失函数
loss_fn = nn.BCELoss()
epochs = 20

我们使用Adam优化器和二元交叉熵损失函数(BCELoss)来计算生成器和判别器的损失。

5. 训练过程

我们开始训练循环,并对生成器和判别器进行优化。

# 开始循环
for epoch in range(epochs):
    for step, (imgs, _) in enumerate(train_dataloader):
        imgs = imgs.to(device)
        size = imgs.size(0)  # 返回第0纬度的大小,也就是batch_size=64
        random_noise = torch.randn(size, 100, device=device)  # 生成64个,大小为100特征值的噪音

        # 鉴别器的优化
        d_optim.zero_grad()
        # 对真实的图片,希望判断为1
        real_output = dis(imgs)
        d_real_loss = loss_fn(real_output,
                               torch.ones_like(real_output))  # 达到鉴别器在真实图片上的损失

        # 对生成器生成的图片,希望对生成器生成的全部判断为0
        gen_img = gen(random_noise)
        fake_output = dis(gen_img).detach()  # 截断梯度,希望判断为0
        d_fake_loss = loss_fn(fake_output,
                               torch.zeros_like(fake_output))  # 得到鉴别器在生成图片上的损失
        # 鉴别器的损失等于鉴别真实图片和鉴别生成图片的损失和
        d_loss = d_fake_loss + d_real_loss
        d_loss.backward()
        d_optim.step()

        # 生成器的优化
        g_optim.zero_grad()
        gen_img = gen(random_noise)
        fake_output = dis(gen_img)
        g_loss = loss_fn(fake_output,
                         torch.ones_like(fake_output))
        g_loss.backward()
        g_optim.step()
        # 将损失记录到tensorboard
        writer.add_scalar('D_Loss_epoch:{}'.format(epoch + 1), d_loss.item(), epoch * len(train_dataloader) + step)
        writer.add_scalar('G_Loss_epoch:{}'.format(epoch + 1), g_loss.item(), epoch * len(train_dataloader) + step)
        # 每一百次进行一次打印
        if step % 100 == 0:
            print(f'Epoch [{epoch + 1}/{epochs}], Step [{step}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}')
            print("Epoch:{},生成器的损失:{},鉴别器的损失:{}".format(epoch + 1, g_loss, d_loss))

在每个训练步骤中,我们首先优化判别器,之后再优化生成器。我们记录损失值以用于后期分析和可视化。

6. 生成和保存图像

在每个epoch结束时,我们将生成的图像保存到文件中。

# 每个 epoch 保存生成的图片
with torch.no_grad():
    gen.eval()
    test_noise = torch.randn(64, 100, device=device)
    i = 0
    generated_images = gen(test_noise).view(-1, 1, 28, 28)
    generated_images = (generated_images + 1) / 2  # 将[-1, 1]转换到[0, 1]
    # 绘图
    grid = torchvision.utils.make_grid(generated_images, nrow=8)  # 将多个生成的图像组合成一个网格图像
    plt.figure(figsize=(8, 8))  # 创建一个新的图形对象,并设置图形的大小为8x8英寸
    plt.imshow(grid.cpu().numpy().transpose((1, 2, 0)))  # 转换为Matplotlib所需的格式
    plt.axis('off')  # 关闭坐标轴和刻度
    if not os.path.exists(f'./images/epoch_{epoch + 1}'):
        os.makedirs(f'./images/epoch_{epoch + 1}')
    plt.savefig(f"./images/epoch_{epoch + 1}/{i}.png")  # 保存生成的图片
    i += 1
    plt.close()

这个部分的代码生成了一些假图片,并将结果保存到指定的目录中,以便后续查看。

总代码

import torchvision
import torch
from matplotlib import pyplot as plt
from torch import nn, Tensor
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import os

#数据准备
transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize(0.5, 0.5)])
train_dataset=torchvision.datasets.MNIST("./dataset",train=True,transform=transform,download=True)
train_dataloader=DataLoader(train_dataset,batch_size=64,shuffle=True,drop_last=True)

#生成器
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model=nn.Sequential(
            nn.Linear(100,128),
            nn.Linear(128,256),
            nn.Linear(256,512),
            nn.Linear(512,28*28),
            nn.Tanh()
        )

    def forward(self,x):
        return self.model(x)
#判别器
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model=nn.Sequential(
            nn.Flatten(),
            nn.Linear(28*28,512),
            nn.LeakyReLU(),
            nn.Linear(512,256),
            nn.LeakyReLU(),
            nn.Linear(256,1),
            nn.Sigmoid()
        )

    def forward(self,x):
        x=x.view(-1,28,28)
        return self.model(x)

#初始化模型,优化器及损失计算函数
#用BCEloss计算交叉熵损失
writer=SummaryWriter("./p1log")
device="cuda" if torch.cuda.is_available() else "cpu"
gen=Generator().to(device)
dis=Discriminator().to(device)
#优化器
g_optim=torch.optim.Adam(gen.parameters(),lr=1e-3)
d_optim=torch.optim.Adam(dis.parameters(),lr=1e-3)
#损失函数
loss_fn=nn.BCELoss()
epochs=20

#开始循环
for epoch in range(epochs):
    for step,(imgs,_) in enumerate(train_dataloader):
        imgs = imgs.to(device)
        size = imgs.size(0)#返回第0纬度的大小,也就是batch_size=64
        random_noise = torch.randn(size, 100, device=device)#生成64个,大小为100特征值的噪音

        # 鉴别器的优化
        d_optim.zero_grad()
        # 对真实的图片,希望判断为1
        real_output = dis(imgs)
        d_real_loss = loss_fn(real_output,
                              torch.ones_like(real_output))  # 达到鉴别器在真实图片上的损失

        # 对gen生成图片,希望对gen生成的全部判断为0
        gen_img = gen(random_noise)
        fake_output = dis(gen_img).detach()  # 截断梯度,希望判断为0
        d_fake_loss = loss_fn(fake_output,
                              torch.zeros_like(fake_output))  # 得到鉴别器在生成图片上的损失
        # 鉴别器的损失等于鉴别真实图片和鉴别生成图片的损失和
        d_loss = d_fake_loss + d_real_loss
        d_loss.backward()
        d_optim.step()

        # 生成器的优化
        g_optim.zero_grad()
        gen_img=gen(random_noise)
        fake_output = dis(gen_img)
        g_loss = loss_fn(fake_output,
                         torch.ones_like(fake_output))
        g_loss.backward()
        g_optim.step()
        #将损失记录到tensorboard
        writer.add_scalar('D_Loss_epoch:{}'.format(epoch+1), d_loss.item(), epoch * len(train_dataloader) + step)
        writer.add_scalar('G_Loss_epoch:{}'.format(epoch+1), g_loss.item(), epoch * len(train_dataloader) + step)
        #每一百次进行一次打印
        if step % 100 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Step [{step}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}')
            print("Epoch:{},生成器的损失:{},鉴别器的损失:{}".format(epoch+1,g_loss,d_loss))
 # 每个 epoch 保存生成的图片
    with torch.no_grad():
        gen.eval()
        test_noise = torch.randn(64, 100, device=device)
        i=0
        generated_images = gen(test_noise).view(-1, 1, 28, 28)
        generated_images = (generated_images + 1) / 2  # 将 [-1, 1] 转换到 [0, 1]
        #绘图
        grid = torchvision.utils.make_grid(generated_images, nrow=8)#torchvision.utils.make_grid 将多个生成的图像(generated_images)组合成一个网格图像,nrow=8 表示每一行显示 8 张图像。
        plt.figure(figsize=(8, 8))#创建一个新的图形对象,并设置图形的大小为 8x8 英寸。
        plt.imshow(grid.cpu().numpy().transpose((1, 2, 0))) #作用是将图像的数据格式从 PyTorch 的默认格式 (C, H, W) 转换为 Matplotlib 所需的格式 (H, W, C),以便正确显示图像。
        plt.axis('off')#关闭坐标轴和刻度
        if not os.path.exists(f'./images/epoch_{epoch+1}'):
            os.makedirs(f'./images/epoch_{epoch+1}')
        plt.savefig(f"./images/epoch_{epoch+1}/{i}.png")  # 保存生成的图片
        i+=1
        plt.close()
好的,下面是pytorch实现GAN的步骤: 1.导入必要的库 ```python import torch import torch.nn as nn import torch.optim as optim import torchvision.datasets as dset import torchvision.transforms as transforms from torch.utils.data import DataLoader import os import numpy as np import torchvision.utils as vutils ``` 2.定义生成器和鉴别器 ```python class Generator(nn.Module): def __init__(self, ngf, nz, nc): super(Generator, self).__init__() self.main = nn.Sequential( nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), nn.BatchNorm2d(ngf * 8), nn.ReLU(True), nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 4), nn.ReLU(True), nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf * 2), nn.ReLU(True), nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False), nn.BatchNorm2d(ngf), nn.ReLU(True), nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False), nn.Tanh() ) def forward(self, input): return self.main(input) class Discriminator(nn.Module): def __init__(self, ndf, nc): super(Discriminator, self).__init__() self.main = nn.Sequential( nn.Conv2d(nc, ndf, 4, 2, 1, bias=False), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 2), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 4), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False), nn.BatchNorm2d(ndf * 8), nn.LeakyReLU(0.2, inplace=True), nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid() ) def forward(self, input): return self.main(input).view(-1, 1).squeeze(1) ``` 3.定义超参数 ```python # 超参数 batch_size = 64 image_size = 64 nz = 100 ngf = 64 ndf = 64 num_epochs = 50 lr = 0.0002 beta1 = 0.5 ngpu = 1 ``` 4.准备数据集 ```python # 图像处理 transform = transforms.Compose([ transforms.Resize(image_size), transforms.CenterCrop(image_size), transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) # 数据集 dataset = dset.ImageFolder(root='./data', transform=transform) dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2) ``` 5.定义优化器和损失函数 ```python # 设备 device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu") # 初始化生成器和鉴别器 netG = Generator(ngf, nz, 3).to(device) netD = Discriminator(ndf, 3).to(device) # 初始化权重 netG.apply(weights_init) netD.apply(weights_init) # 定义损失函数和优化器 criterion = nn.BCELoss() optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999)) optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999)) ``` 6.训练模型 ```python # 真实标签 real_label = 1. # 假标签 fake_label = 0. # 训练 for epoch in range(num_epochs): for i, data in enumerate(dataloader, 0): # 判别器的训练 netD.zero_grad() real_cpu = data[0].to(device) b_size = real_cpu.size(0) label = torch.full((b_size,), real_label, device=device) output = netD(real_cpu) errD_real = criterion(output, label) errD_real.backward() D_x = output.mean().item() noise = torch.randn(b_size, nz, 1, 1, device=device) fake = netG(noise) label.fill_(fake_label) output = netD(fake.detach()) errD_fake = criterion(output, label) errD_fake.backward() D_G_z1 = output.mean().item() errD = errD_real + errD_fake optimizerD.step() # 生成器的训练 netG.zero_grad() label.fill_(real_label) output = netD(fake) errG = criterion(output, label) errG.backward() D_G_z2 = output.mean().item() optimizerG.step() # 输出训练状态 if i % 50 == 0: print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f' % (epoch, num_epochs, i, len(dataloader), errD.item(), errG.item(), D_x, D_G_z1, D_G_z2)) # 保存生成器的输出 if (epoch == 0) and (i == 0): vutils.save_image(real_cpu, '%s/real_samples.png' % "./results", normalize=True) if i % 100 == 0: with torch.no_grad(): fake = netG(fixed_noise).detach().cpu() vutils.save_image(fake, '%s/fake_samples_epoch_%03d.png' % ("./results", epoch), normalize=True) ``` 以上就是pytorch实现GAN的步骤,其中还包括了权重的初始化、训练状态的输出、保存生成器的输出等。这里只是一个简单的示例,实际使用时还需要根据具体问题进行相应的调整和优化。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值