PyTorch学习:GAN基础模型搭建

'''
GAN基础模型搭建
利用MNIST手写字母数据集进行基础GAN程序编写
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision
import torchvision.transforms as transforms

# 一、加载训练数据集
# 数据归一化(-1,1),因为生成器的激活函数是tanh,数据也是(-1,1)
transform = transforms.Compose([
    # 将shape为(H,W,C)的数组或者.img转化为shape为(C,H,W)的tensor
    # 将数据转化为张量,并尽心(0,1)归一化,并且CHW的格式
    transforms.ToTensor(),
    # 将数据从(0,1)之间转化到(-1,1)之间,归一化
    transforms.Normalize(0.5, 0.5),
])
'''
transforms.Compose(): 将多个预处理依次累加在一起,每次执行transform都会依次执行其中包含的多个预处理程序
transforms.ToTensor():在做数据归一化之前必须要把PIL Image转成Tensor
transforms.Normalize(0.5, 0.5):这里的两个0.5分别表示对张量进行归一化的 全局平均值和方差,
因为图像是灰色的只有一个通道,所以分别指定一了一个值,
如果有多个通道,需要有多个数字,
如3个通道,就应该是Normalize([m1, m2, m3], [n1, n2, n3])
'''

# 加载数据集MNIST手写字母
train_ds = torchvision.datasets.MNIST(
    root='../data/MNIST',
    train=True,
    transform=transform,
    download=False
)
test_ds = torchvision.datasets.MNIST(
    root='../data/MNIST',
    train=False,
    transform=transform,
    download=False
)
'''
root :需要下载至地址的根目录位置
train:如果是True, 下载训练集 trainin.pt; 如果是False,下载测试集 test.pt; 默认是True
transform:一系列作用在PIL图片上的转换操作,返回一个转换后的版本
download:是否下载到 root指定的位置,如果指定的root位置已经存在该数据集,则不再下载
'''

train_loader = data.DataLoader(
    dataset=train_ds,
    batch_size=128,
    shuffle=True
)
test_loader = data.DataLoader(
    dataset=test_ds,
    batch_size=128,
    shuffle=True
)
'''
PyTorch中数据读取的一个重要接口是 torch.utils.data.DataLoader。
该接口主要用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch_size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
torch.utils.data.DataLoader(object)的可用参数如下:
dataset(Dataset): 数据读取接口,该输出是torch.utils.data.Dataset类的对象(或者继承自该类的自定义类的对象)。
batch_size (int, optional): 批训练数据量的大小,根据具体情况设置即可。一般为2的N次方(默认:1)
shuffle (bool, optional):是否打乱数据,一般在训练数据中会采用。(默认:False)
'''


# 二、搭建生成器模型
# 输入是长度为100的噪声(正态分布)
# 由于该数据集的CHW是(1,28,28),所以我们的输出也应该为(1,28,28)
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            # 最后一层使用tanh激活,使其分布在(-1,1)之间
            nn.Tanh()
        )

    # 定义前向传播
    def forward(self, x):  # x表示长度为100的噪声输入
        x = self.main(x)
        # 转化成图片的形式
        x = x.view(-1, 28, 28)
        return x


'''
生成器结构:
输出是长度为100的噪声(正态分布随机数)
输出为(1,28,28)的图片
linear 1:100---256
linear 2: 256--512
linear 3:512--784(28*28)
reshape: 784---(1,28,28)
'''


# 三、搭建判别器模型
# 输入为(1,28,28)的图片,输出为二分类的概率值,输出使用sigmoid函数激活
# BCEloss计算交叉熵
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Linear(28 * 28, 512),
            # 判别器一般使用leakyrelu函数激活
            # nn.LeakyReLU f(x):如果X>0输出0,如果X<0输出a*x,a表示一个很小的斜率,比如0.1
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.LeakyReLU(),
            nn.Linear(256, 1),
            # 判别器最后输出一般使用sigmoid的激活0-1
            nn.Sigmoid()
        )

    # 定义前向传播
    def forward(self, x):  # x表示长度为28*28的图片输入
        x = x.view(-1, 28 * 28)
        x = self.main(x)
        return x


'''
输入为为(1,28,28)的图片
输出为二分类(0/1)的概率值,
linear 1:784(28*28)--512
linear 2:512--256
linear 3:256--1
reshape: 1---(0 or 1)
'''

# 四、设备的配置
# 判断设备是否可以调用GPU,如果没有GPU则使用CPU进行计算
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 初始化生成器和判别器,将其调用到响应的设备上
gen = Generator().to(device)
dis = Discriminator().to(device)

# 分别设置生成器和辨别器的优化函数
d_optim = torch.optim.Adam(dis.parameters(), lr=1e-4)
g_optim = torch.optim.Adam(gen.parameters(), lr=1e-4)

# 设置损失函数,BCELoss计算交叉熵损失
loss_func = nn.BCELoss()


# 设置绘图函数
def gen_img_plot(model, test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    plt.close()
    for i in range(test_input.size(0)):
        plt.subplot(4, 4, i + 1)
        plt.imshow((prediction[i] + 1) / 2)  # 将输出恢复在0-1之间的数值,并进行绘图
        plt.axis('off')
    # plt.show()
    # 图像显示暂停一秒钟
    plt.pause(1)


# 测试输入,16个长度为100的随机数
test_input = torch.randn(16, 100, device=device)

# 五、GAN的训练
D_loss = []
G_loss = []

# 训练循环
for epoch in range(20):
    # 初始化损失值
    d_epoch_loss = 0
    g_epoch_loss = 0
    # 返回批次数
    batch_size = len(train_loader)

    # 对数据集进行迭代
    for step, (x, y) in enumerate(train_loader):
        # 把数据传递到设备上
        x = x.to(device)
        # x的第一位是size,获取批次的大小
        size = x.size(0)
        # 根据x的size生成对于数量且长度为100的随机数列,并传递到设备上
        noise = torch.randn(size, 100, device=device)

        # 判别器训练(包含真实图片的损失和生成图片的损失两部分),损失的构建与优化

        # 真实图片损失
        # 梯度归零
        d_optim.zero_grad()
        # 判别器输入真实图片,real_output是对真实图片的预测结果
        real_output = dis(x)
        # 判别器对于真实图片产生的损失
        d_real_loss = loss_func(
            real_output,
            # 期望真实图片的预测结果均为1
            torch.ones_like(real_output)
        )
        # 计算梯度
        d_real_loss.backward()

        # 生成图片损失
        # 使用生成器生成的图片训练判别器,期望判定为假,此时图片使用生成器计算得来的
        # 得到生成器生成的图片
        gen_x = gen(noise)
        # 喂给判别器时要截断梯度,防止更新时把生成器也更新了
        # 判别器输入生成的图片,fake_output对生成图片的预测;detach会截断梯度,梯度不会再传递到生成器模型中、
        fake_output = dis(gen_x.detach())
        # 判别器对于生成器生成图片产生的损失
        d_fake_loss = loss_func(
            fake_output,
            # 期望生成图片的预测结果均为0
            torch.zeros_like(fake_output)
        )
        # 计算梯度
        d_fake_loss.backward()

        # 判别器总的损失
        d_loss = d_fake_loss + d_real_loss
        # 优化判别器
        d_optim.step()

        # 生成器训练

        # 将生成器中的梯度置零
        g_optim.zero_grad()
        # 判别器输入生成的图片
        fake_output = dis(gen_x)
        # 生成器对于判别器中输入生成器生成图片产生的损失
        g_loss = loss_func(
            fake_output,
            # 生成器的损失,对于生成器是期望生成图片的预测结果全1
            torch.ones_like(fake_output)
        )
        # 计算梯度
        g_loss.backward()
        # 优化生成器
        g_optim.step()

        # 累计每一个批次的loss
        with torch.no_grad():
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss

    # 求每一epoch训练的平均损失
    with torch.no_grad():
        d_epoch_loss /= batch_size
        g_epoch_loss /= batch_size
        D_loss.append(d_epoch_loss)
        G_loss.append(g_epoch_loss)
        # 输出当前训练的epoch次数
        print(f'epoch:{epoch + 1}')
        # 利用绘图函数gen_img_plot绘制16个测试输入的生成器生成结果
        gen_img_plot(gen, test_input)

 参考:

pytorch实现GAN网络

基础GAN实例(pytorch代码实现)

GAN代码实战和原理精讲 PyTorch代码进阶 最简明易懂的GAN生成对抗网络入门课程 使用PyTorch编写GAN实例 2021.12最新课程

PyTorch中实现GAN网络通常需要定义两个模型:生成器和判别器。以下是一个简单的GAN网络的示例代码: ```python import torch import torch.nn as nn # 定义生成器模型 class Generator(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(Generator, self).__init__() self.fc1 = nn.Linear(input_size, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, output_size) self.relu = nn.ReLU() self.tanh = nn.Tanh() def forward(self, x): x = self.relu(self.fc1(x)) x = self.relu(self.fc2(x)) x = self.tanh(self.fc3(x)) return x # 定义判别器模型 class Discriminator(nn.Module): def __init__(self, input_size, hidden_size, output_size): super(Discriminator, self).__init__() self.fc1 = nn.Linear(input_size, hidden_size) self.fc2 = nn.Linear(hidden_size, hidden_size) self.fc3 = nn.Linear(hidden_size, output_size) self.sigmoid = nn.Sigmoid() def forward(self, x): x = self.sigmoid(self.fc1(x)) x = self.sigmoid(self.fc2(x)) x = self.sigmoid(self.fc3(x)) return x # 定义损失函数和优化器 criterion = nn.BCELoss() generator = Generator(input_size, hidden_size, output_size) discriminator = Discriminator(input_size, hidden_size, output_size) g_optimizer = torch.optim.Adam(generator.parameters(), lr=learning_rate) d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=learning_rate) # 训练GAN网络 for epoch in range(num_epochs): for i, images in enumerate(train_loader): # 训练生成器 z = torch.randn(batch_size, input_size) fake_images = generator(z) d_fake = discriminator(fake_images) g_loss = criterion(d_fake, torch.ones_like(d_fake)) g_optimizer.zero_grad() g_loss.backward() g_optimizer.step() # 训练判别器 real_images = images.view(-1, input_size) d_real = discriminator(real_images) d_loss_real = criterion(d_real, torch.ones_like(d_real)) d_loss_fake = criterion(d_fake, torch.zeros_like(d_fake)) d_loss = d_loss_real + d_loss_fake d_optimizer.zero_grad() d_loss.backward() d_optimizer.step() ``` 在训练过程中,生成器网络将随机噪声作为输入,生成虚假图像,而判别器网络将真实图像和虚假图像作为输入,尝试区分它们的真伪。损失函数的目标是最小化生成器输出的虚假图像与真实图像之间的差异,并最大化判别器对真实和虚假图像的分类准确性。通过交替训练生成器和判别器,模型将逐渐学会生成更真实的图像。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值