如何用PyTorch构建GAN?

点击上方“小白学视觉”,选择加"星标"或“置顶”
重磅干货,第一时间送达

作者 | Ta-ying Cheng

译者 | Sambodhi

转自 | AI前线

生成对抗网络(Generative Adversarial Network,GAN)由 Goodfellow 等人在 2014 年提出,它彻底改变了计算机视觉中的图像生成领域:没有人能够相信这些令人惊叹而生动的图像实际上是纯粹由机器生成的。

事实上,人们曾经认为生成的任务是不可能的,并且被 GAN 的力量所震惊,因为传统上,根本没有任何事实可以比较我们生成的图像。

本文介绍了创建 GAN 背后的简单直觉,然后介绍了通过 PyTorch 实现的卷积 GAN 及其训练过程。

GAN 背后的直觉

不同于传统分类方法,我们的网络预测可以直接与事实的正确答案相比较,而生成图像的“正确性”是很难定义和衡量的。Goodfellow 等人在他们的原创论文《生成对抗网络》(Generative Adversarial Network)中提出了一个有趣的想法:使用经过训练的分类器来区分生成的图像和实际图像。如果存在这样的分类器,我们可以创建并训练一个生成器网络,直到它输出的图像能完全骗过分类器。

ba9f6b8211aa42063d12ede58de0616f.jpeg

图 1 GAN 管道

GAN 是这一过程的产物:它包含一个根据给定的数据集生成图像的生成器,以及一个区分图像是真实的还是生成的判别器(分类器)。GAN 的详细管道见图 1。

损失函数

对生成器和判别器进行优化都很困难,因为正如你所想象的那样,这两个网络的目标完全相反:生成器希望尽可能地创造出真实的东西,但判别器希望区分生成的材料。

为了说明这一点,我们让 D(x) 是判别器的输出,也就是 x 是真实图像的概率,而 G(z) 是我们的生成器的输出。判别器类似于一个二元分类器,因此判别器的目标是使函数最大化:

本质上是二元交叉熵损失,没有开头的负号。另一方面,生成器的目标是使判别器做出正确判断的机会最小化,因此它的目标是最小化函数。所以,最终的损失函数将是两个分类器之间的一个极小极大博弈(minimax game),具体如下:

457c173d4766fb6fb836abfa5d4c9a5a.jpeg

从理论上讲,这将收敛到判别器,预测所有事件的概率为 0.5。

但在实践中,极小极大博弈往往会导致网络无法收敛,因此仔细调整训练过程非常重要。像学习率这样的超参数对于训练 GAN 时显然更为重要:一个微小的变化会导致 GAN 产生一个输出,而与输入噪声无关。

运算环境

我们通过 PyTorch 库(包括 torchvision)来构建整个程序。GAN 的生成结果的可视化是通过 Matplotlib 库绘制的。下面的代码导入了所有的库:

importGAN.py

"""
Import necessary libraries to create a generative adversarial network
The code is mainly developed using the PyTorch library
"""
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import transforms
from model import discriminator, generator
import numpy as np
import matplotlib.pyplot as plt

数据集

在 GAN 训练中,数据集是一个重要方面。图像的非结构化性质意味着任何给定的类别(如狗、猫或手写的数字)都可以有一个可能的数据分布,而这种分布最终是 GAN 生成内容的基础。

为了演示,本文将使用最简单的 MNIST 数据集,其中包含 60000 张从 0 到 9 的手写数字图像。事实上,像 MNIST 这样的非结构化数据集可以在 Graviti 上找到。这是一家年轻的创业公司,他们希望通过非结构化数据集为社区提供帮助,在他们的 平台 上有一些最好的公共非结构化数据集,包括 MNIST。

硬件要求

最好的方法是用 GPU 训练神经网络,它可以显著地提高训练速度。但是,如果只有 CPU 可用,你仍然可以测试程序。要使你的程序能够自行确定硬件,你可以使用以下方法:

torchDevice.py

"""
Determine if any GPUs are available
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

实施

网络架构

由于数字的简单性,这两种架构——判别器和生成器,都是由全连接层构建的。请注意,在某些情况下,全连接的 GAN 也比 DCGAN 略微容易收敛。

以下是两种架构的 PyTorch 实现:

GANArchitecture.py

"""
Network Architectures
The following are the discriminator and generator architectures
"""

class discriminator(nn.Module):
    def __init__(self):
        super(discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 1)
        self.activation = nn.LeakyReLU(0.1)

    def forward(self, x):
        x = x.view(-1, 784)
        x = self.activation(self.fc1(x))
        x = self.fc2(x)
        return nn.Sigmoid()(x)

class generator(nn.Module):
    def __init__(self):
        super(generator, self).__init__()
        self.fc1 = nn.Linear(128, 1024)
        self.fc2 = nn.Linear(1024, 2048)
        self.fc3 = nn.Linear(2048, 784)
        self.activation = nn.ReLU()

    def forward(self, x):
        x = self.activation(self.fc1(x))
        x = self.activation(self.fc2(x))
        x = self.fc3(x)
        x = x.view(-1, 1, 28, 28)
        return nn.Tanh()(x)

训练

在训练 GAN 时,我们优化了判别器的结果,同时也改进了我们的生成器。这样,在每次迭代过程中会有两个相互矛盾的损失来同时优化它们。我们送入生成器的是随机噪声,而生成器理应根据给定噪声的微小差异来生成图像:

trainGAN.py

"""
Network training procedure
Every step both the loss for disciminator and generator is updated
Discriminator aims to classify reals and fakes
Generator aims to generate images as realistic as possible
"""
for epoch in range(epochs):
    for idx, (imgs, _) in enumerate(train_loader):
        idx += 1

        # Training the discriminator
        # Real inputs are actual images of the MNIST dataset
        # Fake inputs are from the generator
        # Real inputs should be classified as 1 and fake as 0
        real_inputs = imgs.to(device)
        real_outputs = D(real_inputs)
        real_label = torch.ones(real_inputs.shape[0], 1).to(device)

        noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5
        noise = noise.to(device)
        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device)

        outputs = torch.cat((real_outputs, fake_outputs), 0)
        targets = torch.cat((real_label, fake_label), 0)

        D_loss = loss(outputs, targets)
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # Training the generator
        # For generator, goal is to make the discriminator believe everything is 1
        noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5
        noise = noise.to(device)

        fake_inputs = G(noise)
        fake_outputs = D(fake_inputs)
        fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device)
        G_loss = loss(fake_outputs, fake_targets)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if idx % 100 == 0 or idx == len(train_loader):
            print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item()))

    if (epoch+1) % 10 == 0:
        torch.save(G, 'Generator_epoch_{}.pth'.format(epoch))
        print('Model saved.')

结  果

当 100 个轮数(epoch)之后,我们可以绘制数据集,并看到从随机噪音中生成的数字的结果:

de4d351125e84a214dddc6fed9665f32.jpeg

图 2:GAN 生成的结

如上图所示,生成的结果看起来确实相当像真实的结果。鉴于网络非常简单,所以结果看起来确实很有希望!

超越单纯的内容创作

GAN 的创造与计算机视觉领域的先前工作如此不同。随后的众多应用使学术界对深度网络的能力感到惊讶。下面将介绍一些令人惊讶的工作。

CycleGAN

Zhu 等人的 CycleGAN 引入了一种概念,它无需配对样本就可以将图像从 X 域翻译成 Y 域。马被转化为斑马,夏日的阳光被转化为暴风雪,CycleGAN 的结果令人惊讶且准确。

5feae0ed4d6f1367cdd36e4c36b8034b.jpeg

图 3:Zhu 等人的 CycleGAN 生成的结果。

GauGAN

Nvidia 利用 GAN 的力量,把简单的绘画,根据画笔的语义,转换成优雅而逼真的照片。尽管训练资源的计算成本很高,但它创造了一个全新的研究和应用领域。

77c58dc548d3e3e2ecf09a81b0d1ac0e.jpeg

图 4:GaoGAN 的生成结果。左为原图,右为生成的结果。

AdvGAN

GAN 还扩展到清理对抗性图像,并将其转化为不会欺骗分类器的干净样本。关于对抗性攻击和防御的更多信息可以在 这里 到。

结  语

所以,你已经拥有了它!希望这篇文章对如何构建 GAN 提供了一个概览。完整的实现可以在下面的Github 资源库中找到:

https://github.com/ttchengab/MnistGAN

作者简介:

Ta-ying Cheng,中国香港人,牛津大学哲学博士新生,爱好 3D 视觉、深度学习。

原文链接:

https://towardsdatascience.com/building-a-gan-with-pytorch-237b4b07ca9a

好消息!

小白学视觉知识星球

开始面向外开放啦👇👇👇

 
 

6c4e226ecd21e652de3813e2237b9248.jpeg

下载1:OpenCV-Contrib扩展模块中文版教程

在「小白学视觉」公众号后台回复:扩展模块中文教程,即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。


下载2:Python视觉实战项目52讲
在「小白学视觉」公众号后台回复:Python视觉实战项目,即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。


下载3:OpenCV实战项目20讲
在「小白学视觉」公众号后台回复:OpenCV实战项目20讲,即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。


交流群

欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器、自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN、算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值