GAN-MNIST实战


前言

使用生成对抗网络(GAN)生成手写数字的任务


一、GAN原理

基本原理是通过训练两个神经网络——生成器和判别器,并通过对抗学习的方式相互竞争和提高性能,从而生成看起来像真实样本的数据

二、使用步骤

1.引入库

代码如下(示例):

import os.path
import torch
import torchvision
from torch.utils.data import DataLoader
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import MatplotlibDeprecationWarning
import warnings

2. matplotlib异常处理

用于取消警告显示和处理莫名其妙的报错:

warnings.filterwarnings("ignore", category=MatplotlibDeprecationWarning)

os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

3. 数据加载

#数据预处理操作
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),#将图像数据转换为PyTorch张量(Tensor)的格式
    torchvision.transforms.Normalize(0.5,0.5)#对图像进行标准化处理
])
#加载数据
real_data = torchvision.datasets.MNIST(
    root='../FGSM-MNIST-master/data',
    train=True,
    transform=transform,#0-1; channel,high,width
    download=True
)
bs = 128
real_dataloader = DataLoader(
    dataset=real_data,
    batch_size=bs,#批次大小
    shuffle=True#随机打乱
)
imgs,_=next(iter(real_dataloader))#获取图像示例
print(imgs.shape)#输出一个批次图像的数据形状 torch.Size([128, 1, 28, 28])

4. 定义生成器

简化实验,使用三层全连接网络实现

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(in_features=100, out_features=256),
            nn.ReLU(),
            nn.Linear(in_features=256, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=28*28),
            nn.Tanh() #将输入值映射到-1, 1之间
      )

    def forward(self, X):
        return self.network(X).view(-1,28,28,1)##修改形状,-1自动推理,28*28 1通道

5. 定义辨别器

与传统的ReLU函数相比,nn.LeakyReLU()能够更好地处理梯度消失的问题,因为它在负值区域有一个较小的斜率,防止了神经元"死亡"并促进了梯度向后传播。

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.network=nn.Sequential(
            nn.Linear(28*28,512),
            #nn.LeakyReLU f(x) :x>0 输出0,如果x<0,输出 a*x a表示一个很小的斜率,如0.1
            nn.LeakyReLU(),
            nn.Linear(512,256),
            nn.LeakyReLU(),
            nn.Linear(256,1),
            nn.Sigmoid()#转为概率值
        )
    def forward(self,x):
        x = x.view(-1,28*28)
        return self.network(x)

6. 训练迁移/设置优化器

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)

d_optim = torch.optim.Adam(dis.parameters(),lr=2e-4)#辨别优化器
g_optim = torch.optim.Adam(gen.parameters(),lr=2e-4)#生成优化器
loss_fn = torch.nn.BCELoss()#二元交叉熵损失函数

7. 绘制训练图像

def gen_img_plot(model, test_input):
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4, 4))
    for i in range(prediction.shape[0]):
        plt.subplot(4, 4, i+1)
        plt.imshow((prediction[i]+1)/2)  # 0~1之间
        plt.axis('off')
    plt.show()

8. 循环训练

D_loss=[]
G_loss=[]
count = len(real_dataloader)

#训练循环
for epoch in range(20):
    print(f"epoch {epoch + 1}\n-----------------")


    d_epoch_loss=0
    g_epoch_loss = 0

    for i, (X_real, _) in enumerate(real_dataloader):
        ###鉴别器
        X_real = X_real.to(device)

        size = X_real.size(0)

        random_noise=torch.randn(size, 100, device=device)#构造图片

        d_optim.zero_grad()#避免梯度累加

        real_output = dis(X_real)   #判别器输入真实图片,real_output对真实图片的判别结果

        d_real_loss = loss_fn(real_output,
                              torch.ones_like(real_output)) #1.判别器在真实图像的损失

        d_real_loss.backward()#计算真实图像损失


        gen_img = gen(random_noise) #生成假图片

        fake_output = dis(gen_img.detach()) #判别器输入生成的图片,fake_output对生成图片的预测 优化判别器#截断梯度,不利用生成器训练辨别器

        d_fake_loss = loss_fn(fake_output,
                              torch.zeros_like(fake_output)
                            )#2.得到判别器在生成图像上的损失

        d_fake_loss.backward()#计算生成图像的损失

        d_loss= d_real_loss+d_fake_loss

        d_optim.step() #参数更新操作

        ###生成器

        g_optim.zero_grad()#避免梯度累加

        fake_output=dis(gen_img)#需要记录梯度

        g_loss=loss_fn(fake_output,
                       torch.ones_like(fake_output)
                   ) #生成器损失

        g_loss.backward()
        g_optim.step() #参数更新操作
        with torch.no_grad():#不需要进行反向传播或计算梯度,只对损失进行累加
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss

        if i % 100 == 0:
            print(
                f"loss_G: {d_loss.item()}, loss_D: {g_loss.item()}, D(x): {d_real_loss.item()}, D(G(z)): {d_fake_loss.item()}")
    with torch.no_grad():
            d_epoch_loss /=count
            g_epoch_loss /=count
            D_loss.append(d_epoch_loss)
            G_loss.append(g_epoch_loss)

            gen_img_plot(gen, test_input)

torch.save(gen.state_dict(), 'model_my.pth')

三、训练过程示例

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述


  • 7
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值