Generating Images with GANs (Generative Adversarial Networks) in PyTorch

This post shows how to set up and train a generative adversarial network (GAN) in PyTorch to generate images. It covers data loading, the network definitions (Generator and Discriminator), the loss function, the optimizers, and the training loop. During training, the generator's and discriminator's weights are updated in alternation so that the two networks stay in balance.
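Concretely, the two networks play the standard minimax game from Goodfellow et al. (2014): the discriminator D maximizes E_x[log D(x)] + E_z[log(1 - D(G(z)))], while the generator G tries to minimize the second term. The training loop below implements this with BCELoss: D is trained on a real batch with label 1 and a fake batch with label 0, then G is trained on the same fake batch with label 1, the usual non-saturating trick of maximizing log D(G(z)) instead of minimizing log(1 - D(G(z))).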
from __future__ import print_function

import os
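# Work around "duplicate OpenMP runtime" crashes (common with Anaconda on Windows)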
os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'

#%matplotlib inline
import argparse
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

if __name__ == '__main__':  # guard required on Windows, since num_workers > 0 spawns worker processes

    # Root directory of the dataset
    dataroot = "E:/images"

    # Number of worker processes for the DataLoader
    workers = 2

    # Batch size during training
    batch_size = 128

    # Spatial size of the training images (all images are resized to this)
    image_size = 64

    # Number of channels in the training images (3 for color)
    nc = 3

    # Size of the latent vector z (the generator's input)
    nz = 100

    # Size of the feature maps in the generator
    ngf = 64

    # Size of the feature maps in the discriminator
    ndf = 64

    # Number of training epochs
    num_epochs = 50

    # Learning rate for the optimizers
    lr = 0.0002

    # Beta1 hyperparameter for the Adam optimizers
    beta1 = 0.5

    # Number of GPUs available (use 0 for CPU mode)
    ngpu = 1

    # Create the dataset
    dataset = dset.ImageFolder(root=dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(image_size),
                                   transforms.CenterCrop(image_size),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    # Create the dataloader that feeds batches to the model
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=True, num_workers=workers)

    # Decide which device to run on (GPU if available)
    device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

    # Plot some of the training images
    real_batch = next(iter(dataloader))
    plt.figure(figsize=(10,10))
    plt.axis("off")
    plt.title("Training Images")
    plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu(),(1,2,0)))
    plt.show()

    # Custom weights initialization, called on netG and netD
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    # Generator code

    class Generator(nn.Module):
        def __init__(self, ngpu):
            super(Generator, self).__init__()
            self.ngpu = ngpu
            self.main = nn.Sequential(
                # input is Z, going into a convolution
                nn.ConvTranspose2d( nz, ngf * 8, 4, 1, 0, bias=False),
                nn.BatchNorm2d(ngf * 8),
                nn.ReLU(True),
                # state size. (ngf*8) x 4 x 4
                nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 4),
                nn.ReLU(True),
                # state size. (ngf*4) x 8 x 8
                nn.ConvTranspose2d( ngf * 4, ngf * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf * 2),
                nn.ReLU(True),
                # state size. (ngf*2) x 16 x 16
                nn.ConvTranspose2d( ngf * 2, ngf, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ngf),
                nn.ReLU(True),
                # state size. (ngf) x 32 x 32
                nn.ConvTranspose2d( ngf, nc, 4, 2, 1, bias=False),
                nn.Tanh()
                # state size. (nc) x 64 x 64
            )

        def forward(self, input):
            return self.main(input)

    # Create the generator
    netG = Generator(ngpu).to(device)

    # Handle multi-GPU if desired
    if (device.type == 'cuda') and (ngpu > 1):
        netG = nn.DataParallel(netG, list(range(ngpu)))

    # Apply the weights_init function to randomly initialize all weights
    # to mean=0, stdev=0.02
    netG.apply(weights_init)


    print(netG)

    # Discriminator code
    class Discriminator(nn.Module):
        def __init__(self, ngpu):
            super(Discriminator, self).__init__()
            self.ngpu = ngpu
            self.main = nn.Sequential(
                # input is (nc) x 64 x 64
                nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf) x 32 x 32
                nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 2),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*2) x 16 x 16
                nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 4),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*4) x 8 x 8
                nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
                nn.BatchNorm2d(ndf * 8),
                nn.LeakyReLU(0.2, inplace=True),
                # state size. (ndf*8) x 4 x 4
                nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
                nn.Sigmoid()
            )

        def forward(self, input):
            return self.main(input)

    # Create the discriminator
    netD = Discriminator(ngpu).to(device)

    # Handle multi-GPU if desired
    if (device.type == 'cuda') and (ngpu > 1):
        netD = nn.DataParallel(netD, list(range(ngpu)))

    # Apply the weights_init function to randomly initialize all weights
    # to mean=0, stdev=0.02
    netD.apply(weights_init)

    print(netD)

    # Initialize the BCELoss function
    criterion = nn.BCELoss()
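    # BCELoss: loss(x, y) = -[ y*log(x) + (1-y)*log(1-x) ]; with y = 1 for real
    # and y = 0 for fake, this yields exactly the log terms of the GAN objective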

    # Create a batch of fixed latent vectors used to visualize the generator's progress
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)

    # Establish conventions for real and fake labels during training
    real_label = 1.
    fake_label = 0.

    # Set up Adam optimizers for both G and D
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))


    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    iters = 0

    print("Starting Training Loop...")
    # For each epoch
    for epoch in range(num_epochs):
        # For each batch in the dataloader
        for i, data in enumerate(dataloader, 0):


            ############################
            # (1) Update the discriminator: maximize log(D(x)) + log(1 - D(G(z)))
            ############################
            # Train with an all-real batch
            netD.zero_grad()
            # Format the batch
            real_cpu = data[0].to(device)
            b_size = real_cpu.size(0)
            label = torch.full((b_size,), real_label, dtype=torch.float, device=device)
            # Forward pass the real batch through D
            output = netD(real_cpu).view(-1)
            # Calculate the loss on the all-real batch
            errD_real = criterion(output, label)
            # Calculate the gradients for D in the backward pass
            errD_real.backward()
            D_x = output.mean().item()

            # Train with an all-fake batch
            # Generate a batch of latent vectors
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            # Generate a batch of fake images with G
            fake = netG(noise)
            label.fill_(fake_label)
            # Classify the all-fake batch with D
            output = netD(fake.detach()).view(-1)
            # Calculate D's loss on the all-fake batch
            errD_fake = criterion(output, label)
            # Calculate the gradients for this batch, accumulated (summed) with the previous gradients
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            # Compute D's total error as the sum over the all-real and all-fake batches
            errD = errD_real + errD_fake
            # Update D
            optimizerD.step()


            ############################
            # (2) Update the generator: maximize log(D(G(z)))
            ############################
            netG.zero_grad()
            label.fill_(real_label)  # fake labels are treated as real for the generator's loss
            # Since D was just updated, perform another forward pass of the all-fake batch through D
            output = netD(fake).view(-1)
            # Calculate G's loss based on this output
            errG = criterion(output, label)
            # Calculate the gradients for G
            errG.backward()
            D_G_z2 = output.mean().item()
            # Update G
            optimizerG.step()

            # Output training stats
            if i % 50 == 0:
                print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                      % (epoch, num_epochs, i, len(dataloader),
                         errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

            # Save the losses for plotting later
            G_losses.append(errG.item())
            D_losses.append(errD.item())

            # Check how the generator is doing by saving G's output on fixed_noise
            if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
                with torch.no_grad():
                    fake = netG(fixed_noise).detach().cpu()
                img_list.append(vutils.make_grid(fake, padding=3, normalize=True))

            iters += 1

    plt.figure(figsize=(10,5))
    plt.title("Generator and Discriminator Loss During Training")
    plt.plot(G_losses,label="G")
    plt.plot(D_losses,label="D")
    plt.xlabel("iterations")
    plt.ylabel("Loss")
    plt.legend()
    plt.show()

    # Grab a batch of real images from the dataloader
    real_batch = next(iter(dataloader))

    # Plot the real images
    plt.figure(figsize=(15,15))
    plt.subplot(1,2,1)
    plt.axis("off")
    plt.title("Real Images")
    plt.imshow(np.transpose(vutils.make_grid(real_batch[0].to(device)[:64], padding=3, normalize=True).cpu(),(1,2,0)))

    # Plot the fake images from the last epoch
    plt.subplot(1,2,2)
    plt.axis("off")
    plt.title("Fake Images")
    plt.imshow(np.transpose(img_list[-1],(1,2,0)))
    plt.show()
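    # The matplotlib.animation and IPython HTML imports at the top are otherwise
    # unused; what follows is a minimal sketch of the usual follow-up step, which
    # animates the snapshots collected in img_list. Note that HTML(...) renders
    # inline only in a Jupyter notebook; in a plain script, call plt.show() instead.
    fig = plt.figure(figsize=(8,8))
    plt.axis("off")
    ims = [[plt.imshow(np.transpose(img,(1,2,0)), animated=True)] for img in img_list]
    ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
    HTML(ani.to_jshtml())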
