概要
GAN(Generative Adversarial Network)生成对抗网络。其实要理解GAN的构想逻辑并不难,像其他的一些模型比如说最最基础的nn.Linear() + nn.ReLU()
,或者是RNN模型,我们不妨把这个模型看成一位武侠,他的目的是要跟江湖上尽可能多的人(data)过招(train),目的是在未来遇到邪恶的坏蛋(真实情景应用)时能够一招制敌(给出正确的结果)。
但是天不遂人愿,在茫茫的人海中,真正的武林高手有几个?又有几个能被我遇到?今天打过了丐帮的降龙十八掌,明天谁知道会不会被一记九阳神功拍的头昏眼花?(能接触到的数据总是有限的)武侠仰天沉思,他想起那一年去西域,自己仗着在中原打遍天下无敌手(过拟合)四处张扬得不行,结果被一旁的扫地大爷一套西洋拳术带走(模型不适应其他数据)。
可是家中有老母要照顾,忠孝难两全,武侠也因此一直呆在中原。由于放眼神州已无敌手,便打起了木人桩。机会总是留给有准备的人,有一天武侠捡到了阿拉丁神灯,神灯答应了他的愿望,点化了他的木人桩,让他能主动与武侠打斗并且不断增强自己的武力值,直击武侠痛点。武侠大喜,从此开始了与被点化的木人桩的切磋之路,技艺日增,终成一代地球大侠。
这个木人桩和武侠就是GAN中的Generator和Discriminator。对于Discriminator而言,它的目标是分辨出真假数据,对于Generator而言,它的目标是要制造出能以假乱真的数据。在学习的过程中,Generator的输入我们用torch.randn
产生随机数据,以此希望通过Generator产生各种各样的输入。
简单地说,二者的目标总结为:
-
Discriminator: 给定数据x,我希望分辨出这是真实产生的数据,还是Generator模拟的假数据,输出0-1
-
Generator: 给定随机数random,我希望能蒙混过关,尽可能模拟真实数据 output.shape == x.shape
二者在训练的过程中我们应该可以看到两边的loss大致是一个此消彼长的关系,这也是GAN中A(Adversarial)的本意,两者对抗。
一个例子 (base on MNIST)
我用在暑假跟着学深度学习中一课时的代码复现给大家分享一下。MNIST数据集是一个图片集,都是手写的单个数字,有images和labels两个部分。用torch.utils.data.Dataset
或者torchvision.Datasets.MNIST
可以读入为dataset
实例,进一步构造dataloader
。废话不多说,上代码。
# 一些经常用的库和函数
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
# 定义一些超参数
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 64
latent_size = 64 # 就是上文中生成的random的长度
image_size = 28*28 # 这是MNIST数据集中图片的大小
hidden_size = 256 # 定义Discriminator和Generator模型中的隐层的大小
output_size = 10 # 最终输出为10维的向量,代表1~