PyTorch学习:GAN基础模型搭建

'''
GAN基础模型搭建
利用MNIST手写字母数据集进行基础GAN程序编写
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

# 一、加载训练数据集
# 数据归一化(-1,1),因为生成器的激活函数是tanh,数据也是(-1,1)
transform = transforms.Compose([
    # 将shape为(H,W,C)的数组或者.img转化为shape为(C,H,W)的tensor
    # 将数据转化为张量,并尽心(0,1)归一化,并且CHW的格式
    transforms.ToTensor(),
    # 将数据从(0,1)之间转化到(-1,1)之间,归一化
    transforms.Normalize(0.5, 0.5),
])
'''
transforms.Compose(): 将多个预处理依次累加在一起,每次执行transform都会依次执行其中包含的多个预处理程序
transforms.ToTensor():在做数据归一化之前必须要把PIL Image转成Tensor
transforms.Normalize(0.5, 0.5):这里的两个0.5分别表示对张量进行归一化的 全局平均值和方差,
因为图像是灰色的只有一个通道,所以分别指定一了一个值,
如果有多个通道,需要有多个数字,
如3个通道,就应该是Normalize([m1, m2, m3], [n1, n2, n3])
'''

# 加载数据集MNIST手写字母
train_ds = torchvision.datasets.MNIST(
    root='../data/MNIST',
    train=True,
    transform=transform,
    download=False
)
test_ds = torchvision.datasets.MNIST(
    root='../data/MNIST',
    train=False,
    transform=transform,
    download=False
)
'''
root :需要下载至地址的根目录位置
train:如果是True, 下载训练集 trainin.pt; 如果是False,下载测试集 test.pt; 默认是True
transform:一系列作用在PIL图片上的转换操作,返回一个转换后的版本
download:是否下载到 root指定的位置,如果指定的root位置已经存在该数据集,则不再下载
'''

train_loader = data.DataLoader(
    dataset=train_ds,
    batch_size=128,
    shuffle=True
)
test_loader = data.DataLoader(
    dataset=test_ds,
    batch_size=128,
    shuffle=True
)
'''
PyTorch中数据读取的一个重要接口是 torch.utils.data.DataLoader。
该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch_size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
torch.utils.data.DataLoader(object)的可用参数如下:
dataset(Dataset): 数据读取接口,该输出是torch.utils.data.Dataset类的对象(或者继承自该类的自定义类的对象)。
batch_size (int, optional): 批训练数据量的大小,根据具体情况设置即可。一般为2的N次方(默认:1)
shuffle (bool, optional):是否打乱数据,一般在训练数据中会采用。(默认:False)
'''


# 二、搭建生成器模型
# 输入是长度为100的噪声(正态分布)
# 由于该数据集的CHW是(1,28,28),所以我们的输出也应该为(1,28,28)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            # 最后一层使用tanh激活,使其分布在(-1,1)之间
            nn.Tanh()
        )

    # 定义前向传播
    def forward(self, x):  # x表示长度为100的噪声输入
        x = self.main(x)
        # 转化成图片的形式
        x = x.view(-1, 28, 28)
        return x


'''
生成器结构:
输出是长度为100的噪声(正态分布随机数)
输出为(1,28,28)的图片
linear 1:100---256
linear 2: 256--512
linear 3:512--784(28*28)
reshape: 784---(1,28,28)
'''


# 三、搭建判别器模型
# 输入为(1,28,28)的图片,输出为二分类的概率值,输出使用sigmoid函数激活
# BCEloss计算交叉熵
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(28 * 28, 512),
            # 判别器一般使用leakyrelu函数激活
            # nn.LeakyReLU f(x):如果X>0输出0,如果X<0输出a*x,a表示一个很小的斜率,比如0.1
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            # 判别器最后输出一般使用sigmoid的激活0-1
            nn.Sigmoid()
        )

    # 定义前向传播
    def forward(self, x):  # x表示长度为28*28的图片输入
        x = x.view(-1, 28 * 28)
        x = self.main(x)
        return x


'''
输入为为(1,28,28)的图片
输出为二分类(0/1)的概率值,
linear 1:784(28*28)--512
linear 2:512--256
linear 3:256--1
reshape: 1---(0 or 1)
'''

# 四、设备的配置
# 判断设备是否可以调用GPU,如果没有GPU则使用CPU进行计算
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 初始化生成器和判别器,将其调用到响应的设备上
gen = Generator().to(device)
dis = Discriminator().to(device)

# 分别设置生成器和辨别器的优化函数
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4)
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)

# 设置损失函数,BCELoss计算交叉熵损失
loss_func = nn.BCELoss()


# 设置绘图函数
def gen_img_plot(model, test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    plt.close()
    for i in range(test_input.size(0)):
        plt.subplot(4, 4, i + 1)
        plt.imshow((prediction[i] + 1) / 2)  # 将输出恢复在0-1之间的数值,并进行绘图
        plt.axis('off')
    # plt.show()
    # 图像显示暂停一秒钟
    plt.pause(1)


# 测试输入,16个长度为100的随机数
test_input = torch.randn(16, 100, device=device)

# 五、GAN的训练
D_loss = []
G_loss = []

# 训练循环
for epoch in range(20):
    # 初始化损失值
    d_epoch_loss = 0
    g_epoch_loss = 0
    # 返回批次数
    batch_size = len(train_loader)

    # 对数据集进行迭代
    for step, (x, y) in enumerate(train_loader):
        # 把数据传递到设备上
        x = x.to(device)
        # x的第一位是size,获取批次的大小
        size = x.size(0)
        # 根据x的size生成对于数量且长度为100的随机数列,并传递到设备上
        noise = torch.randn(size, 100, device=device)

        # 判别器训练(包含真实图片的损失和生成图片的损失两部分),损失的构建与优化

        # 真实图片损失
        # 梯度归零
        d_optim.zero_grad()
        # 判别器输入真实图片,real_output是对真实图片的预测结果
        real_output = dis(x)
        # 判别器对于真实图片产生的损失
        d_real_loss = loss_func(
            real_output,
            # 期望真实图片的预测结果均为1
            torch.ones_like(real_output)
        )
        # 计算梯度
        d_real_loss.backward()

        # 生成图片损失
        # 使用生成器生成的图片训练判别器,期望判定为假,此时图片使用生成器计算得来的
        # 得到生成器生成的图片
        gen_x = gen(noise)
        # 喂给判别器时要截断梯度,防止更新时把生成器也更新了
        # 判别器输入生成的图片,fake_output对生成图片的预测;detach会截断梯度,梯度不会再传递到生成器模型中、
        fake_output = dis(gen_x.detach())
        # 判别器对于生成器生成图片产生的损失
        d_fake_loss = loss_func(
            fake_output,
            # 期望生成图片的预测结果均为0
            torch.zeros_like(fake_output)
        )
        # 计算梯度
        d_fake_loss.backward()

        # 判别器总的损失
        d_loss = d_fake_loss + d_real_loss
        # 优化判别器
        d_optim.step()

        # 生成器训练

        # 将生成器中的梯度置零
        g_optim.zero_grad()
        # 判别器输入生成的图片
        fake_output = dis(gen_x)
        # 生成器对于判别器中输入生成器生成图片产生的损失
        g_loss = loss_func(
            fake_output,
            # 生成器的损失,对于生成器是期望生成图片的预测结果全1
            torch.ones_like(fake_output)
        )
        # 计算梯度
        g_loss.backward()
        # 优化生成器
        g_optim.step()

        # 累计每一个批次的loss
        with torch.no_grad():
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss

    # 求每一epoch训练的平均损失
    with torch.no_grad():
        d_epoch_loss /= batch_size
        g_epoch_loss /= batch_size
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        # 输出当前训练的epoch次数
        print(f'epoch:{epoch + 1}')
        # 利用绘图函数gen_img_plot绘制16个测试输入的生成器生成结果
        gen_img_plot(gen, test_input)

 参考:

pytorch实现GAN网络

基础GAN实例(pytorch代码实现)

GAN代码实战和原理精讲 PyTorch代码进阶 最简明易懂的GAN生成对抗网络入门课程 使用PyTorch编写GAN实例 2021.12最新课程

  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值