PyTorch学习:GAN基础模型搭建(二)

'''
GAN基础模型搭建
利用faces动漫头像数据集进行基础GAN程序编写
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
import tqdm

# 一、加载训练数据集
# 数据归一化(-1,1),因为生成器的激活函数是tanh,数据也是(-1,1)
transform = transforms.Compose([
    # 将图片转换为96*96的尺寸
    transforms.Resize(96),
    # 将图片中心截取为96*96的尺寸
    transforms.CenterCrop(96),
    # 将shape为(H,W,C)的数组或者.img转化为shape为(C,H,W)的tensor
    # 将数据转化为张量,并尽心(0,1)归一化,并且CHW的格式
    transforms.ToTensor(),
    # 将数据从(0,1)之间转化到(-1,1)之间,归一化
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
'''
transforms.Compose(): 将多个预处理依次累加在一起,每次执行transform都会依次执行其中包含的多个预处理程序
transforms.Resize():调整图片的尺寸
transforms.CenterCrop():以图片中心为基准,按规定大小截取图片
transforms.ToTensor():在做数据归一化之前必须要把PIL Image转成Tensor
transforms.Normalize(0.5, 0.5):这里的两个0.5分别表示对张量进行归一化的 全局平均值和方差,
因为图像是灰色的只有一个通道,所以分别指定一了一个值,
如果有多个通道,需要有多个数字,
如3个通道,就应该是Normalize([m1, m2, m3], [n1, n2, n3])
'''

# 加载数据集faces动漫头像
train_data_dir = '../data/faces' # 图像数据储存路径
train_ds = ImageFolder(
    train_data_dir,
    transform=transform
)

test_ds = ImageFolder(
    train_data_dir,
    transform=transform
)

'''
ImageFolder:设置图片的所在文件夹和图片的调整函数
'''

train_data_loader = data.DataLoader(
    dataset=train_ds,
    batch_size=16,
    shuffle=True,
    num_workers=0
)
test_loader = data.DataLoader(
    dataset=test_ds,
    batch_size=16,
    shuffle=True
)
'''
PyTorch中数据读取的一个重要接口是 torch.utils.data.DataLoader。
该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch_size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
torch.utils.data.DataLoader(object)的可用参数如下:
dataset(Dataset): 数据读取接口,该输出是torch.utils.data.Dataset类的对象(或者继承自该类的自定义类的对象)。
batch_size (int, optional): 批训练数据量的大小,根据具体情况设置即可。一般为2的N次方(默认:1)
shuffle (bool, optional):是否打乱数据,一般在训练数据中会采用。(默认:False)
'''


# 二、搭建生成器模型
# 输入是长度为100的噪声(正态分布)
# 由于该数据集的CHW是(3,96,96),所以我们的输出也应该为(3,96,96)
class NetG(nn.Module):
    """
    生成器定义
    """

    def __init__(self):
        super(NetG, self).__init__()
        # 生成器feature map数
        ngf = 64

        self.main = nn.Sequential(
            # 输入是一个nz维度的噪声,我们可以认为它是一个1*1*nz的feature map
            nn.ConvTranspose2d(100, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # 上一步的输出形状:(ngf*8) x 4 x 4

            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*4) x 8 x 8

            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # 上一步的输出形状: (ngf*2) x 16 x 16

            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # 上一步的输出形状:(ngf) x 32 x 32

            nn.ConvTranspose2d(ngf, 3, 5, 3, 1, bias=False),
            nn.Tanh()  # 输出范围 -1~1 故而采用Tanh
            # 输出形状:3 x 96 x 96
        )

    def forward(self, input):
        return self.main(input)
'''
生成器结构:
输出是长度为100的噪声(正态分布随机数)
输出为(3,96,96)的图片
'''


# 三、搭建判别器模型
# 输入为(3,96,96)的图片,输出为二分类的概率值,输出使用sigmoid函数激活
# BCEloss计算交叉熵
class NetD(nn.Module):
    """
    判别器定义
    """

    def __init__(self):
        super(NetD, self).__init__()
        # 判别器feature map数
        ndf = 64
        self.main = nn.Sequential(
            # 输入 3 x 96 x 96
            nn.Conv2d(3, ndf, 5, 3, 1, bias=False),
            # 判别器一般使用leakyrelu函数激活
            # nn.LeakyReLU f(x):如果X>0输出0,如果X<0输出a*x,a表示一个很小的斜率,比如0.1
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf) x 32 x 32

            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*2) x 16 x 16

            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*4) x 8 x 8

            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # 输出 (ndf*8) x 4 x 4

            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
            # 判别器最后输出一般使用sigmoid的激活0-1
            nn.Sigmoid()
            # 输出一个数(概率)
        )

    def forward(self, input):
        return self.main(input).view(-1)
'''
输入为为(3,96,96)的图片
输出为二分类(0/1)的概率值,
'''

# 四、设备的配置
# 判断设备是否可以调用GPU,如果没有GPU则使用CPU进行计算
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 初始化生成器和判别器,将其调用到响应的设备上
gen = NetG().to(device)
dis = NetD().to(device)

# 分别设置生成器和辨别器的优化函数
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4)
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)

# 设置损失函数,BCELoss计算交叉熵损失
loss_func = nn.BCELoss()


# 设置绘图函数
def gen_img_plot(model, test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    # print(prediction.shape)
    fig = plt.figure(figsize=(4, 4))
    plt.close()
    for i in range(test_input.size(0)):
        plt.subplot(4, 4, i + 1)
        # 将输出恢复在0-1之间的数值,并进行绘图
        # plt.imshow((prediction[i] + 1) / 2)
        '''
        由于使用matplotlib显示彩色图像需要数据的维度为[width, height, channel],就是96 * 96 * 3
        而tensor的维度为 3 * 224 * 224,需要用.T对Tensor进行维度顺序颠倒
        才能保证画图顺利
        '''
        plt.imshow(((prediction[i] + 1) / 2).T)
        plt.axis('off')
    # plt.show()
    # 图像显示暂停一秒钟
    plt.pause(1)


# 测试输入,16个长度为100的随机数
test_input = torch.randn(16, 100, 1, 1, device=device)

# 五、GAN的训练
D_loss = []
G_loss = []

# 训练循环
#  Generator 训练1次,Discriminator训练5次 总迭代100次
for epoch in range(400):
    # 初始化损失值
    d_epoch_loss = 0
    g_epoch_loss = 0
    # 返回批次数
    batch_size = len(train_data_loader)

    # 对数据集进行迭代
    for step, (x, y) in tqdm.tqdm(enumerate(train_data_loader)):
        # 把数据传递到设备上
        x = x.to(device)
        # x的第一位是size,获取批次的大小
        size = x.size(0)
        # 根据x的size生成对于数量且长度为100的随机数列,并传递到设备上
        noise = torch.randn(size, 100, 1, 1, device=device)

        # 判别器训练(包含真实图片的损失和生成图片的损失两部分),损失的构建与优化
        if step % 1 == 0:
            # 真实图片损失
            # 梯度归零
            d_optim.zero_grad()
            # 判别器输入真实图片,real_output是对真实图片的预测结果
            real_output = dis(x)
            # 判别器对于真实图片产生的损失
            d_real_loss = loss_func(
                real_output,
                # 期望真实图片的预测结果均为1
                torch.ones_like(real_output)
            )
            # 计算梯度
            d_real_loss.backward()

            # 生成图片损失
            # 使用生成器生成的图片训练判别器,期望判定为假,此时图片使用生成器计算得来的
            # 得到生成器生成的图片
            gen_x = gen(noise)
            # 喂给判别器时要截断梯度,防止更新时把生成器也更新了
            # 判别器输入生成的图片,fake_output对生成图片的预测;detach会截断梯度,梯度不会再传递到生成器模型中、
            fake_output = dis(gen_x.detach())
            # 判别器对于生成器生成图片产生的损失
            d_fake_loss = loss_func(
                fake_output,
                # 期望生成图片的预测结果均为0
                torch.zeros_like(fake_output)
            )
            # 计算梯度
            d_fake_loss.backward()

            # 判别器总的损失
            d_loss = d_fake_loss + d_real_loss
            # 优化判别器
            d_optim.step()


        # 生成器训练
        if step % 5 == 0:
            # 将生成器中的梯度置零
            g_optim.zero_grad()
            # 判别器输入生成的图片
            fake_output = dis(gen_x)
            # 生成器对于判别器中输入生成器生成图片产生的损失
            g_loss = loss_func(
                fake_output,
                # 生成器的损失,对于生成器是期望生成图片的预测结果全1
                torch.ones_like(fake_output)
            )
            # 计算梯度
            g_loss.backward()
            # 优化生成器
            g_optim.step()

        # 累计每一个批次的loss
        with torch.no_grad():
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss

    # 求每一epoch训练的平均损失
    with torch.no_grad():
        d_epoch_loss /= batch_size
        g_epoch_loss /= batch_size
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        # 输出当前训练的epoch次数
        print(f'epoch:{epoch + 1}')
        # 利用绘图函数gen_img_plot绘制16个测试输入的生成器生成结果
        gen_img_plot(gen, test_input)

参考

电子工业出版社的《深度学习框架PyTorch:入门与实践》第七章的配套代码

pytorch实现GAN网络

Pytorch实现GAN 生成动漫头像

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值