生成对抗网络GAN 最简单入门学习


欢迎点赞和收藏,后续会继续分享文章,记得关注
本文为365天深度学习训练营中的博客学习总结
原作者:k同学呀
文章来源:K同学的学习圈

学习内容:

  1. 了解对抗生成网络的原理
  2. 学习使用最简单的对抗生成网络GAN

一. 理论基础

通俗理解:左手打右手,两手共提高

例子:警察抓小偷,两者相互竞争,共同提高
例子:老虎抓山羊,共同进化
例子:判别网络判别真伪,生成网络造假,最终共同进化,判别网络判别不了,生成网络以假乱真。

原理解释:

GAN,Generative Adversarial Networks,也即生成对抗网络。并不指代某一个具体的神经网络,而是指一类基于博弈思想:相关竞争共同提高而设计的神经网络。

GAN由两个分别被称为生成器(Generator)和判别器(Discriminator)的神经网络组成。
其中,生成器从某种噪声分布中随机采样作为输入,输出与训练集中真实样本非常相似的人工样本;进行造假。
判别器的输入则为真实样本或人工样本,其目的是将人工样本与真实样本尽可能地区分出来。进行识别。
生成器和判别器交替运行,相互博弈,各自的能力都得到升。理想情况下,经过足够次数的博弈之后,判别器无法判断给定样本的真实性,即对于所有样本都输出50%真,50%假的判断。此时,生成器输出的人工样本已经逼真到使判别器无法分辨真假,停止博弈。这样就可以得到一个具有“伪造”真实样本能力的生成器。

简单GAN(对抗生成网络)讲解

  1. 最简单的判别网络即为图片二分类网络,由一些卷积层,线性层,激活层构成。可以替换成任意高级的网络
  2. 最简单的生成网络 即为生成图片拉长维度向量的网络,主要由线性层构成。先生成100维的随机向量,每个数值为正态分布,之后线性层依次升维即可。类似于U-Net的后半部分。也类似与AE,或者transformer 中的后半部分

二. 代码解读

1. 定义超参

import argparse
import os
import numpy as np
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable
import torch.nn as nn
import torch

## 创建文件夹
os.makedirs("./data/images/", exist_ok=True)         ## 记录训练过程的图片效果
os.makedirs("./data/save/", exist_ok=True)           ## 训练完成时模型保存的位置
os.makedirs("./data/mnist", exist_ok=True)      ## 下载数据集存放的位置

## 超参数配置
n_epochs=50
batch_size=64
lr=0.0002
b1=0.5
b2=0.999
n_cpu=2
latent_dim=100
img_size=28
channels=1
sample_interval=500

## 图像的尺寸:(1, 28, 28),  和图像的像素面积:(784)
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)

## 设置cuda:(cuda:0)
cuda = True if torch.cuda.is_available() else False
print(cuda)

2. 下载数据

## mnist数据集下载
mnist = datasets.MNIST(
    root='./data/', train=True, download=True, transform=transforms.Compose(
            [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]), 
)

3. 配置数据

## 配置数据到加载器
dataloader = DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,
)

4. 配置模型

4.1 鉴别器

##### 定义判别器 Discriminator ######
## 将图片28x28展开成784,然后通过多层感知器,中间经过斜率设置为0.2的LeakyReLU激活函数,
## 最后接sigmoid激活函数得到一个0到1之间的概率进行二分类
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),         # 输入特征数为784,输出为512
            nn.LeakyReLU(0.2, inplace=True),  # 进行非线性映射
            nn.Linear(512, 256),              # 输入特征数为512,输出为256
            nn.LeakyReLU(0.2, inplace=True),  # 进行非线性映射
            nn.Linear(256, 1),                # 输入特征数为256,输出为1
            nn.Sigmoid(),                     # sigmoid是一个激活函数,二分类问题中可将实数映射到[0, 1],作为概率值, 多分类用softmax函数
        )

    def forward(self, img):
        img_flat = img.view(img.size(0), -1) # 鉴别器输入是一个被view展开的(784)的一维图像:(64, 784)
        validity = self.model(img_flat)      # 通过鉴别器网络
        return validity                      # 鉴别器返回的是一个[0, 1]间的概率

4.2 生成器

###### 定义生成器 Generator #####
## 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维,
## 然后通过LeakyReLU激活函数,接着进行一个线性变换,再经过一个LeakyReLU激活函数,
## 然后经过线性变换将其变成784维,最后经过Tanh激活函数是希望生成的假的图片数据分布, 能够在-1~1之间。
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        ## 模型中间块儿
        def block(in_feat, out_feat, normalize=True):        # block(in, out )
            layers = [nn.Linear(in_feat, out_feat)]          # 线性变换将输入映射到out维
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8)) # 正则化
            layers.append(nn.LeakyReLU(0.2, inplace=True))   # 非线性激活函数
            return layers
        ## prod():返回给定轴上的数组元素的乘积:1*28*28=784
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False), # 线性变化将输入映射 100 to 128, 正则化, LeakyReLU
            *block(128, 256),                         # 线性变化将输入映射 128 to 256, 正则化, LeakyReLU
            *block(256, 512),                         # 线性变化将输入映射 256 to 512, 正则化, LeakyReLU
            *block(512, 1024),                        # 线性变化将输入映射 512 to 1024, 正则化, LeakyReLU
            nn.Linear(1024, img_area),                # 线性变化将输入映射 1024 to 784
            nn.Tanh()                                 # 将(784)的数据每一个都映射到[-1, 1]之间
        )
    ## view():相当于numpy中的reshape,重新定义矩阵的形状:这里是reshape(64, 1, 28, 28)
    def forward(self, z):                           # 输入的是(64, 100)的噪声数据
        imgs = self.model(z)                        # 噪声数据通过生成器模型
        imgs = imgs.view(imgs.size(0), *img_shape)  # reshape成(64, 1, 28, 28)
        return imgs                                 # 输出为64张大小为(1, 28, 28)的图像

5. 创建实例并训练

#创建实例
## 创建生成器,判别器对象
generator = Generator()
discriminator = Discriminator()

## 首先需要定义loss的度量方式  (二分类的交叉熵)
criterion = torch.nn.BCELoss()

## 其次定义 优化函数,优化函数的学习率为0.0003
## betas:用于计算梯度以及梯度平方的运行平均值的系数
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

## 如果有显卡,都在cuda模式中运行
if torch.cuda.is_available():
    generator     = generator.cuda()
    discriminator = discriminator.cuda()
    criterion     = criterion.cuda()

# 训练
## 进行多个epoch的训练
for epoch in range(n_epochs):                   # epoch:50
    for i, (imgs, _) in enumerate(dataloader):  # imgs:(64, 1, 28, 28)     _:label(64)
        
        ## =============================训练判别器==================
        ## view(): 相当于numpy中的reshape,重新定义矩阵的形状, 相当于reshape(128,784)  原来是(128, 1, 28, 28)
        imgs = imgs.view(imgs.size(0), -1)    # 将图片展开为28*28=784  imgs:(64, 784)
        real_img = Variable(imgs).cuda()      # 将tensor变成Variable放入计算图中,tensor变成variable之后才能进行反向传播求梯度
        real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()      ## 定义真实的图片label为1
        fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()     ## 定义假的图片的label为0

        ## ---------------------
        ##  Train Discriminator
        ## 分为两部分:1、真的图像判别为真;2、假的图像判别为假
        ## ---------------------
        ## 计算真实图片的损失
        real_out = discriminator(real_img)            # 将真实图片放入判别器中
        loss_real_D = criterion(real_out, real_label) # 得到真实图片的loss
        real_scores = real_out                        # 得到真实图片的判别值,输出的值越接近1越好
        ## 计算假的图片的损失
        ## detach(): 从当前计算图中分离下来避免梯度传到G,因为G不用更新
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## 随机生成一些噪声, 大小为(128, 100)
        fake_img    = generator(z).detach()                                    ## 随机噪声放入生成网络中,生成一张假的图片。 
        fake_out    = discriminator(fake_img)                                  ## 判别器判断假的图片
        loss_fake_D = criterion(fake_out, fake_label)                       ## 得到假的图片的loss
        fake_scores = fake_out                                              ## 得到假图片的判别值,对于判别器来说,假图片的损失越接近0越好
        ## 损失函数和优化
        loss_D = loss_real_D + loss_fake_D  # 损失包括判真损失和判假损失
        optimizer_D.zero_grad()             # 在反向传播之前,先将梯度归0
        loss_D.backward()                   # 将误差反向传播
        optimizer_D.step()                  # 更新参数

        ## -----------------
        ##  Train Generator
        ## 原理:目的是希望生成的假的图片被判别器判断为真的图片,
        ## 在此过程中,将判别器固定,将假的图片传入判别器的结果与真实的label对应,
        ## 反向传播更新的参数是生成网络里面的参数,
        ## 这样可以通过更新生成网络里面的参数,来训练网络,使得生成的图片让判别器以为是真的, 这样就达到了对抗的目的
        ## -----------------
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()      ## 得到随机噪声
        fake_img = generator(z)                                             ## 随机噪声输入到生成器中,得到一副假的图片
        output = discriminator(fake_img)                                    ## 经过判别器得到的结果
        ## 损失函数和优化
        loss_G = criterion(output, real_label)                              ## 得到的假的图片与真实的图片的label的loss
        optimizer_G.zero_grad()                                             ## 梯度归0
        loss_G.backward()                                                   ## 进行反向传播
        optimizer_G.step()                                                  ## step()一般用在反向传播后面,用于更新生成网络的参数

        ## 打印训练过程中的日志
        ## item():取出单元素张量的元素值并返回该值,保持原元素类型不变
        if (i + 1) % 300 == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [D real: %f] [D fake: %f]"
                % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item(), real_scores.data.mean(), fake_scores.data.mean())
            )
        ## 保存训练过程中的图像
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(fake_img.data[:25], "./data/images/%d.png" % batches_done, nrow=5, normalize=True)

[Epoch 40/50] [Batch 299/938] [D loss: 0.911972] [G loss: 1.264209] [D real: 0.615513] [D fake: 0.190590]
[Epoch 40/50] [Batch 599/938] [D loss: 0.808387] [G loss: 1.492971] [D real: 0.700173] [D fake: 0.242417]
[Epoch 40/50] [Batch 899/938] [D loss: 0.945543] [G loss: 1.199135] [D real: 0.641663] [D fake: 0.264667]
[Epoch 41/50] [Batch 299/938] [D loss: 0.767059] [G loss: 1.265899] [D real: 0.729328] [D fake: 0.256548]
[Epoch 41/50] [Batch 599/938] [D loss: 0.768910] [G loss: 1.395854] [D real: 0.698200] [D fake: 0.252217]
[Epoch 41/50] [Batch 899/938] [D loss: 0.828038] [G loss: 1.425793] [D real: 0.676624] [D fake: 0.215663]
[Epoch 42/50] [Batch 299/938] [D loss: 0.786111] [G loss: 1.149347] [D real: 0.679647] [D fake: 0.240033]
[Epoch 42/50] [Batch 599/938] [D loss: 0.926057] [G loss: 1.093621] [D real: 0.611771] [D fake: 0.213749]
[Epoch 42/50] [Batch 899/938] [D loss: 1.051074] [G loss: 1.875097] [D real: 0.780975] [D fake: 0.452756]
[Epoch 43/50] [Batch 299/938] [D loss: 0.949325] [G loss: 1.513711] [D real: 0.725759] [D fake: 0.343389]
[Epoch 43/50] [Batch 599/938] [D loss: 1.072763] [G loss: 1.187461] [D real: 0.599758] [D fake: 0.256029]
[Epoch 43/50] [Batch 899/938] [D loss: 0.936638] [G loss: 1.526531] [D real: 0.653215] [D fake: 0.225740]
[Epoch 44/50] [Batch 299/938] [D loss: 0.783098] [G loss: 1.213412] [D real: 0.683519] [D fake: 0.254234]
[Epoch 44/50] [Batch 599/938] [D loss: 1.032195] [G loss: 1.537429] [D real: 0.784595] [D fake: 0.430993]
[Epoch 44/50] [Batch 899/938] [D loss: 0.841602] [G loss: 1.410475] [D real: 0.681418] [D fake: 0.202322]
[Epoch 45/50] [Batch 299/938] [D loss: 0.913109] [G loss: 1.560791] [D real: 0.653888] [D fake: 0.276527]
[Epoch 45/50] [Batch 599/938] [D loss: 0.827896] [G loss: 1.580814] [D real: 0.702827] [D fake: 0.265357]
[Epoch 45/50] [Batch 899/938] [D loss: 0.721954] [G loss: 1.603410] [D real: 0.723822] [D fake: 0.233086]
[Epoch 46/50] [Batch 299/938] [D loss: 0.811522] [G loss: 1.570267] [D real: 0.741978] [D fake: 0.327063]
[Epoch 46/50] [Batch 599/938] [D loss: 1.024942] [G loss: 1.180711] [D real: 0.570806] [D fake: 0.164485]
[Epoch 46/50] [Batch 899/938] [D loss: 0.845878] [G loss: 1.584793] [D real: 0.627794] [D fake: 0.176946]
[Epoch 47/50] [Batch 299/938] [D loss: 1.037055] [G loss: 1.048466] [D real: 0.606069] [D fake: 0.248434]
[Epoch 47/50] [Batch 599/938] [D loss: 0.873520] [G loss: 1.529568] [D real: 0.665836] [D fake: 0.266380]
[Epoch 47/50] [Batch 899/938] [D loss: 0.909397] [G loss: 0.988783] [D real: 0.699651] [D fake: 0.329167]
[Epoch 48/50] [Batch 299/938] [D loss: 0.762955] [G loss: 1.757964] [D real: 0.691470] [D fake: 0.175175]
[Epoch 48/50] [Batch 599/938] [D loss: 0.751731] [G loss: 1.582819] [D real: 0.695594] [D fake: 0.232840]
[Epoch 48/50] [Batch 899/938] [D loss: 0.920607] [G loss: 2.062231] [D real: 0.841761] [D fake: 0.449201]
[Epoch 49/50] [Batch 299/938] [D loss: 0.803595] [G loss: 1.586111] [D real: 0.708683] [D fake: 0.264936]
[Epoch 49/50] [Batch 599/938] [D loss: 0.767607] [G loss: 1.513981] [D real: 0.702909] [D fake: 0.244586]
[Epoch 49/50] [Batch 899/938] [D loss: 1.228767] [G loss: 0.696118] [D real: 0.512739] [D fake: 0.159993]

6. 保存模型

## 保存模型
torch.save(generator.state_dict(), './data/save/generator.pth')
torch.save(discriminator.state_dict(), './data/save/discriminator.pth')
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值