Deep Learning Bootcamp: Generative Adversarial Networks

Environment

  • Language: Python 3.10.13
  • Editor: Jupyter Notebook
  • Deep learning framework: PyTorch

A Brief Introduction to GANs (Generative Adversarial Networks)

A GAN is not one specific network architecture but a game-theoretic training framework for neural networks. It consists of two parts, a Generator and a Discriminator: the generator produces artificial samples, and the discriminator tries to tell them apart from real samples. During training the two networks are updated in alternation and play against each other; training stops when the discriminator can no longer reliably distinguish the artificial samples from the real ones.
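
Formally, the two networks play a minimax game over a shared value function (the standard formulation from Goodfellow et al., 2014):

\[
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\bigl[\log D(x)\bigr] + \mathbb{E}_{z \sim p_z(z)}\bigl[\log\bigl(1 - D(G(z))\bigr)\bigr]
\]

At the theoretical equilibrium, the generator's samples match the data distribution and the discriminator outputs 1/2 everywhere.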

Setup

Defining the hyperparameters

import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch

## Create the output folders
os.makedirs("./images/", exist_ok=True)         ## sample images produced during training
os.makedirs("./save/", exist_ok=True)           ## where the trained models are saved
os.makedirs("./datasets/mnist", exist_ok=True)  ## where the downloaded dataset is stored

## Hyperparameter configuration
n_epochs=50            # number of training epochs
batch_size=512         # samples per batch
lr=0.0002              # Adam learning rate
b1=0.5                 # Adam beta1 (decay of the first-moment running average)
b2=0.999               # Adam beta2 (decay of the second-moment running average)
n_cpu=2                # CPU threads available for data loading
latent_dim=100         # dimension of the noise vector z
img_size=28            # height/width of the (square) images
channels=1             # MNIST images are single-channel (grayscale)
sample_interval=500    # save sample images every 500 batches

## Image shape: (1, 28, 28); total pixels per image: 784
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)

## Set up CUDA (cuda:0)
cuda = True if torch.cuda.is_available() else False
print(cuda)

True
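
Note that argparse is imported above but never used in this post; as a small sketch (the flag names are illustrative, not from the original), the same hyperparameters could instead be read from the command line:

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=50, help="number of training epochs")
parser.add_argument("--batch_size", type=int, default=512, help="samples per batch")
parser.add_argument("--lr", type=float, default=0.0002, help="Adam learning rate")
parser.add_argument("--latent_dim", type=int, default=100, help="dimension of the noise vector z")
opt = parser.parse_args(args=[])  # args=[] keeps this runnable inside a notebook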

Downloading the dataset

## Download the MNIST dataset
mnist = datasets.MNIST(
    root='./datasets/', train=True, download=True,
    transform=transforms.Compose(
        [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
    ),
)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./datasets/MNIST/raw/train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:00<00:00, 438576233.89it/s]
Extracting ./datasets/MNIST/raw/train-images-idx3-ubyte.gz to ./datasets/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./datasets/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 26106830.57it/s]
Extracting ./datasets/MNIST/raw/train-labels-idx1-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./datasets/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 95251028.09it/s]
Extracting ./datasets/MNIST/raw/t10k-images-idx3-ubyte.gz to ./datasets/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 5843720.48it/s]
Extracting ./datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./datasets/MNIST/raw

Configuring the data

## Load the dataset into a DataLoader
dataloader = DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,
)
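
As a quick sanity check (an addition, not part of the original post), pulling one batch from the loader confirms the tensor shapes and the [-1, 1] value range produced by Normalize:

imgs, labels = next(iter(dataloader))
print(imgs.shape)    # expected: torch.Size([512, 1, 28, 28])
print(labels.shape)  # expected: torch.Size([512])
print(imgs.min().item(), imgs.max().item())  # approximately -1.0 and 1.0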

Model Training

Model definition

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),         # 784 input features, 512 output features
            nn.LeakyReLU(0.2, inplace=True),  # non-linear activation
            nn.Linear(512, 256),              # 512 input features, 256 output features
            nn.LeakyReLU(0.2, inplace=True),  # non-linear activation
            nn.Linear(256, 1),                # 256 input features, 1 output feature
            nn.Sigmoid(),                     # maps the real-valued output to [0, 1] as a probability (binary classification; a multi-class problem would use softmax)
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1) # flatten each image into a 784-dim vector: (batch_size, 784)
        validity = self.model(img_flat)      # pass through the discriminator network
        return validity                      # a probability in [0, 1]

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        ## A reusable building block for the model
        def block(in_feat, out_feat, normalize=True):        # block(in, out)
            layers = [nn.Linear(in_feat, out_feat)]          # linear layer mapping in_feat to out_feat dimensions
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8)) # batch normalization (note: 0.8 is passed positionally, so it sets eps)
            layers.append(nn.LeakyReLU(0.2, inplace=True))   # non-linear activation
            return layers
        ## prod(): product of the array elements along a given axis: 1*28*28=784
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False), # linear map 100 -> 128, LeakyReLU (no batch norm)
            *block(128, 256),                         # linear map 128 -> 256, batch norm, LeakyReLU
            *block(256, 512),                         # linear map 256 -> 512, batch norm, LeakyReLU
            *block(512, 1024),                        # linear map 512 -> 1024, batch norm, LeakyReLU
            nn.Linear(1024, img_area),                # linear map 1024 -> 784
            nn.Tanh()                                 # squash each of the 784 outputs into [-1, 1]
        )
    ## view(): like numpy's reshape; here it restores the image shape (batch_size, 1, 28, 28)
    def forward(self, z):                           # input: noise of shape (batch_size, 100)
        imgs = self.model(z)                        # pass the noise through the generator
        imgs = imgs.view(imgs.size(0), *img_shape)  # reshape to (batch_size, 1, 28, 28)
        return imgs                                 # batch_size images of shape (1, 28, 28)
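
Before training, a short smoke test (an addition to the original post; the batch size of 4 is arbitrary) verifies that the two networks plug into each other with the expected shapes:

g, d = Generator(), Discriminator()
z = torch.randn(4, latent_dim)  # 4 random noise vectors
fake = g(z)                     # expected shape: (4, 1, 28, 28)
prob = d(fake)                  # expected shape: (4, 1), values in [0, 1]
print(fake.shape, prob.shape)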

Training loop

## Instantiate the generator and the discriminator
generator = Generator()
discriminator = Discriminator()

## First, define how the loss is measured (binary cross-entropy)
criterion = torch.nn.BCELoss()

## Next, define the optimizers; the learning rate is lr=0.0002
## betas: coefficients for the running averages of the gradient and of its square
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

## Run everything on the GPU if one is available
if torch.cuda.is_available():
    generator     = generator.cuda()
    discriminator = discriminator.cuda()
    criterion     = criterion.cuda()
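
BCELoss pairs with the Sigmoid output of the discriminator: for a predicted probability $p$ and a target label $y \in \{0, 1\}$, the per-sample binary cross-entropy is

\[
\ell(p, y) = -\bigl[\, y \log p + (1 - y) \log (1 - p) \,\bigr]
\]

so in the loop below, loss_real_D pushes D(real) toward 1, while loss_fake_D pushes D(fake) toward 0.
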
for epoch in range(n_epochs):                   # n_epochs: 50
    for i, (imgs, _) in enumerate(dataloader):  # imgs: (batch_size, 1, 28, 28), _: labels (batch_size)

        imgs = imgs.view(imgs.size(0), -1)    # flatten each image to 28*28=784: (batch_size, 784)
        real_img = Variable(imgs).cuda()      # wrap the tensor in a Variable so it joins the computation graph and gradients can flow (legacy API; plain tensors suffice in PyTorch >= 0.4)
        real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()      ## label real images as 1
        fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()     ## label fake images as 0


        ## ---- Train the discriminator ----
        real_out = discriminator(real_img)            # feed the real images to the discriminator
        loss_real_D = criterion(real_out, real_label) # loss on the real images
        real_scores = real_out                        # discriminator scores on real images; the closer to 1, the better
        ## Loss on the fake images
        ## detach(): cut the graph here so no gradient flows back into G (G is not updated in this step)
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## sample random noise of shape (batch_size, 100)
        fake_img    = generator(z).detach()                             ## generate fake images from the noise
        fake_out    = discriminator(fake_img)                           ## discriminator scores on the fake images
        loss_fake_D = criterion(fake_out, fake_label)                   ## loss on the fake images
        fake_scores = fake_out                                          ## the closer to 0, the better
        ## Combined loss and optimization
        loss_D = loss_real_D + loss_fake_D  # total loss: real loss plus fake loss
        optimizer_D.zero_grad()             # zero the gradients before backpropagating
        loss_D.backward()                   # backpropagate the error
        optimizer_D.step()                  # update the parameters


        ## ---- Train the generator ----
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## sample fresh random noise
        fake_img = generator(z)                                         ## generate fake images (no detach here: G needs the gradients)
        output = discriminator(fake_img)                                ## discriminator scores for the fakes
        ## Loss and optimization
        loss_G = criterion(output, real_label)                          ## the generator wants its fakes to be scored as real
        optimizer_G.zero_grad()                                         ## zero the gradients
        loss_G.backward()                                               ## backpropagate
        optimizer_G.step()                                              ## step() after backward() updates the generator's parameters

        ## Log the training progress
        ## item(): extract the scalar value from a single-element tensor, keeping its type
        if ( i + 1 ) % 100 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean())
            )
        ## Save sample images during training
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./images/%d.png" % batches_done, nrow=5, normalize=True)
[Epoch 0/50] [Batch 99/118] [D loss: 1.392201] [G loss: 0.720495] [D real: 0.660287] [D fake: 0.605812]
[Epoch 1/50] [Batch 99/118] [D loss: 1.237947] [G loss: 0.645015] [D real: 0.627908] [D fake: 0.521414]
[Epoch 2/50] [Batch 99/118] [D loss: 1.206743] [G loss: 1.371007] [D real: 0.831680] [D fake: 0.612845]
[Epoch 3/50] [Batch 99/118] [D loss: 1.042226] [G loss: 1.403389] [D real: 0.738673] [D fake: 0.510074]
[Epoch 4/50] [Batch 99/118] [D loss: 0.950092] [G loss: 1.879504] [D real: 0.826950] [D fake: 0.522136]
[Epoch 5/50] [Batch 99/118] [D loss: 0.964685] [G loss: 1.957089] [D real: 0.719685] [D fake: 0.451049]
[Epoch 6/50] [Batch 99/118] [D loss: 1.250396] [G loss: 3.282628] [D real: 0.890914] [D fake: 0.671558]
[Epoch 7/50] [Batch 99/118] [D loss: 1.436018] [G loss: 2.394736] [D real: 0.889253] [D fake: 0.720616]
[Epoch 8/50] [Batch 99/118] [D loss: 0.997427] [G loss: 1.298990] [D real: 0.646482] [D fake: 0.406250]
[Epoch 9/50] [Batch 99/118] [D loss: 1.015665] [G loss: 1.133040] [D real: 0.583219] [D fake: 0.313580]
[Epoch 10/50] [Batch 99/118] [D loss: 1.104344] [G loss: 0.890277] [D real: 0.487496] [D fake: 0.231743]
[Epoch 11/50] [Batch 99/118] [D loss: 0.805449] [G loss: 1.476539] [D real: 0.771647] [D fake: 0.393425]
[Epoch 12/50] [Batch 99/118] [D loss: 0.886394] [G loss: 1.365783] [D real: 0.625857] [D fake: 0.272819]
[Epoch 13/50] [Batch 99/118] [D loss: 1.160892] [G loss: 2.270098] [D real: 0.819089] [D fake: 0.600196]
[Epoch 14/50] [Batch 99/118] [D loss: 0.990126] [G loss: 2.221423] [D real: 0.845885] [D fake: 0.544953]
[Epoch 15/50] [Batch 99/118] [D loss: 0.814652] [G loss: 1.242788] [D real: 0.629658] [D fake: 0.231237]
[Epoch 16/50] [Batch 99/118] [D loss: 1.292980] [G loss: 2.503134] [D real: 0.784579] [D fake: 0.633823]
[Epoch 17/50] [Batch 99/118] [D loss: 1.031758] [G loss: 2.699657] [D real: 0.815461] [D fake: 0.542024]
[Epoch 18/50] [Batch 99/118] [D loss: 0.988402] [G loss: 1.569268] [D real: 0.678417] [D fake: 0.390312]
[Epoch 19/50] [Batch 99/118] [D loss: 1.008053] [G loss: 2.010935] [D real: 0.820956] [D fake: 0.532101]
[Epoch 20/50] [Batch 99/118] [D loss: 0.928145] [G loss: 1.021322] [D real: 0.581234] [D fake: 0.226135]
[Epoch 21/50] [Batch 99/118] [D loss: 0.901849] [G loss: 1.050935] [D real: 0.586989] [D fake: 0.204523]
[Epoch 22/50] [Batch 99/118] [D loss: 0.741626] [G loss: 1.595031] [D real: 0.732048] [D fake: 0.289632]
[Epoch 23/50] [Batch 99/118] [D loss: 1.299593] [G loss: 0.612525] [D real: 0.399023] [D fake: 0.091832]
[Epoch 24/50] [Batch 99/118] [D loss: 0.892590] [G loss: 1.221449] [D real: 0.614085] [D fake: 0.213004]
[Epoch 25/50] [Batch 99/118] [D loss: 0.911475] [G loss: 1.191773] [D real: 0.555970] [D fake: 0.123458]
[Epoch 26/50] [Batch 99/118] [D loss: 0.995571] [G loss: 2.681632] [D real: 0.854862] [D fake: 0.527149]
[Epoch 27/50] [Batch 99/118] [D loss: 1.160800] [G loss: 0.873747] [D real: 0.452850] [D fake: 0.120306]
[Epoch 28/50] [Batch 99/118] [D loss: 0.812079] [G loss: 3.594522] [D real: 0.888997] [D fake: 0.474353]
[Epoch 29/50] [Batch 99/118] [D loss: 0.629567] [G loss: 1.881860] [D real: 0.751697] [D fake: 0.232717]
[Epoch 30/50] [Batch 99/118] [D loss: 0.990794] [G loss: 1.515161] [D real: 0.638545] [D fake: 0.315502]
[Epoch 31/50] [Batch 99/118] [D loss: 0.719803] [G loss: 2.252865] [D real: 0.770139] [D fake: 0.299516]
[Epoch 32/50] [Batch 99/118] [D loss: 0.692556] [G loss: 2.410813] [D real: 0.821995] [D fake: 0.356726]
[Epoch 33/50] [Batch 99/118] [D loss: 0.804586] [G loss: 1.265048] [D real: 0.616957] [D fake: 0.143595]
[Epoch 34/50] [Batch 99/118] [D loss: 1.002946] [G loss: 0.998340] [D real: 0.542722] [D fake: 0.129555]
[Epoch 35/50] [Batch 99/118] [D loss: 0.638164] [G loss: 2.125413] [D real: 0.781333] [D fake: 0.267743]
[Epoch 36/50] [Batch 99/118] [D loss: 0.796073] [G loss: 2.105274] [D real: 0.733458] [D fake: 0.275736]
[Epoch 37/50] [Batch 99/118] [D loss: 0.707837] [G loss: 1.712856] [D real: 0.723462] [D fake: 0.213725]
[Epoch 38/50] [Batch 99/118] [D loss: 0.574488] [G loss: 1.854247] [D real: 0.762389] [D fake: 0.191676]
[Epoch 39/50] [Batch 99/118] [D loss: 0.665478] [G loss: 2.232678] [D real: 0.820965] [D fake: 0.334450]
[Epoch 40/50] [Batch 99/118] [D loss: 0.721808] [G loss: 2.323272] [D real: 0.814131] [D fake: 0.358871]
[Epoch 41/50] [Batch 99/118] [D loss: 0.794077] [G loss: 1.027562] [D real: 0.625926] [D fake: 0.107704]
[Epoch 42/50] [Batch 99/118] [D loss: 0.809034] [G loss: 1.199560] [D real: 0.628133] [D fake: 0.120829]
[Epoch 43/50] [Batch 99/118] [D loss: 0.679062] [G loss: 2.121304] [D real: 0.762760] [D fake: 0.273702]
[Epoch 44/50] [Batch 99/118] [D loss: 0.565462] [G loss: 1.739519] [D real: 0.764866] [D fake: 0.179747]
[Epoch 45/50] [Batch 99/118] [D loss: 0.788362] [G loss: 1.181516] [D real: 0.638391] [D fake: 0.095294]
[Epoch 46/50] [Batch 99/118] [D loss: 0.761360] [G loss: 3.341885] [D real: 0.876044] [D fake: 0.434379]
[Epoch 47/50] [Batch 99/118] [D loss: 0.755073] [G loss: 2.987507] [D real: 0.849711] [D fake: 0.408325]
[Epoch 48/50] [Batch 99/118] [D loss: 0.737976] [G loss: 1.869045] [D real: 0.739559] [D fake: 0.229506]
[Epoch 49/50] [Batch 99/118] [D loss: 0.726861] [G loss: 1.471112] [D real: 0.676479] [D fake: 0.113519]

Saving the models

torch.save(generator.state_dict(), './generator.pth')
torch.save(discriminator.state_dict(), './discriminator.pth')
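
To reuse the trained generator later, here is a minimal loading-and-sampling sketch (an addition to the original post; the output file name is illustrative):

generator = Generator()
generator.load_state_dict(torch.load('./generator.pth', map_location='cpu'))
generator.eval()  # put the BatchNorm layers into inference mode

with torch.no_grad():
    z = torch.randn(25, latent_dim)  # 25 noise vectors
    samples = generator(z)           # (25, 1, 28, 28), values in [-1, 1]
save_image(samples, './images/generated.png', nrow=5, normalize=True)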