根据数字生成唯一数字_经典GAN实战教程:理解并运行自己的GAN生成手写数字

4b267b5ffa4cefb36e91bd2e2b98d571.png

  新智元报道   

来源:Medium

编辑:元子

【新智元导读】本文从GAN的简介、理解和评估GAN、运行自己的GAN3个方面,试图让读者理解生成性对抗性网络(GAN)研究和评估是如何开发的,然后实现用自己的GAN来生成手写数字。偏实战向,针对读者为学生、开发者以及对动手操作GAN感兴趣的读者。

本文主要是以下3个部分:

  1. 了解什么是GAN

  2. 理解和评估GAN

  3. 运行自己的GAN

希望通过本文,读者能够了解如何评估GAN,并最终能够动手运行自己的GAN生成MNIST等手写数字。

7f78ee8e8a3f85adf24ac1dda0348ca9.png

GAN的简要介绍

自2014年Ian Goodfellow的发布“Generative Adversarial Networks”论文成立以来,GAN的进展申诉,应用场景也越来越广泛。

2e5fa7f492d9cb2f1131d3cd1d096c9c.png

就在三年前,GAN之父Ian Goodfellow认为GAN仅针对实值数据进行了定义。由于所有NLP都基于离散值,如单词,字符或字节,因此没有人真正知道如何将GAN应用于NLP。

d0a8e19e895ba6d5cec1d48fcd97fb19.png

然而现在GAN用于创建各种内容,包括图像,视频,音频和文本。这些输出可用作训练其他模型的合成数据,或仅用于产生有趣的side项目,例如

  • https://thispersondoesnotexist.com/

  • https://thisairbnbdoesnotexist.com/

GAN是什么?

GAN由一个生成网络与一个判别网络组成。

2a2ea3d014526aec397814e12ba0fc9a.png

生成网络从潜在空间(latent space)中随机采样作为输入,其输出结果需要尽量模仿训练集中的真实样本。判别网络的输入则为真实样本或生成网络的输出,其目的是将生成网络的输出从真实样本中尽可能分辨出来。而生成网络则要尽可能地欺骗判别网络。两个网络相互对抗、不断调整参数,最终目的是使判别网络无法判断生成网络的输出结果是否真实。

一句话解释:G要骗过A。

如果你需要对GAN有更深入的理解,推荐你学习以下课程:

  • Stanford CS231 Lecture 13 — Generative Models

  • Style-based GANs

  • Understanding Generative Adversarial Networks

  • Introduction to Generative Adversarial Networks

  • Lillian Weng: From Gan to WGAN

  • Dive head first into advanced GANs: exploring self-attention and spectral norm

  • Guim Perarnau: Fantastic GANs and where to find them (Parts I & II)

理解和评估GAN

量化GAN的进度感觉上非常主观,“这个生成的面部是否看起来足够逼真?”、“这些生成的图像是否足够多样化?”而且GAN的黑盒属性使得我们并不清楚模型的哪些组件对学习过程和结果质量造成影响。

为此,MIT计算机科学与人工智能(CSAIL)实验室的一个小组最近发表了一篇论文《GAN Dissection: Visualizing and Understanding Generative Adversarial Networks》,该论文介绍了一种可视化GAN,以及GAN units是如何像对象之间的相关性一样,形成与图像中对象的相关性。

81f6bd2c700bcc3060a6057cffb4b651.png

通过介入某些GAN unit显示图像修改

论文提出使用基于分割的网络剖析方法,允许对generator神经网络的内部工作进行剖析和可视化。这步操作可以通过寻找一组GAN unit(称为神经元)与输出图像中的概念(例如树、天空、云等)之间的协议来实现。因此,我们能够识别出对某些物体(如建筑物或云)起作用的神经元。

将这种粒度级别放入神经元中,可以允许通过强制激活和去激活(消融)这些对象的相应unit,来编辑现有图像(例如,添加或移除图像中所示的树)。

然而目前尚不清楚网络是否能够推断场景中的对象,或者它是否只是记住这些对象。接近这个问题答案的一种方法,是试图以不切实际的方式扭曲图像。

也许MIT CSAIL的GAN Paint互动网络演示中,最令人印象深刻的部分是该模型似乎能够将这些编辑限制在“真实感”。例如试图将草坪放在天空上,结果就会变成这样:

2058626ecd15725372513256bd9e30d5.gif

即使我们激活相应的神经元,看起来GAN已经抑制了后续层中的信号。

c3e6b78f68c7571a1f7c879158747943.png

上图演示了在建筑物上生成门而非在生成到树上或在天空中的可能性

可视化GAN的另一种有趣方式是可以进行潜在空间插值(记住,GAN通过从学习的潜在空间中采样来生成新实例),这可能是查看生成的样本之间过渡的平滑程度的有用方法。

这些可视化可以帮助我们理解GAN的内部表示,但是找到可量化的方法来理解GAN进度和输出质量仍然是一个亟待解决的研究领域。

98ab10fcd12d3f2b2075bf09ca299430.gif

图像质量和多样性的两个常用评估指标:Inception Score(IS)和Fréchet Inception Distance(FID):

Inception Score

Inception Score由Salimans等人提出,并在2016年发表论文《用于训练GAN的改进技术》。

Inception Score基于一种启发式方法,即当通过预先训练的网络(例如ImageNet上的Inception)传递时,现实样本应该能够被分类。 从技术上讲,这意味着样本应具有低熵softmax预测向量。

除了高可预测性(低熵)之外,Inception Score还基于所生成的样本的多样性来评估GAN(例如,生成的样本的分布的高方差或熵)。 

如果这两个特征都得到满足,那么应该有一个很大的Inception Score。结合这两个标准的方法是评估样本的条件标签分布与所有样本的边际分布之间的Kullback-Leibler(KL)差异。

Fréchet Inception Distance

FID由Heusel等人在2017年提出。FID通过测量生成的图像分布与真实分布之间的距离来估计真实感。

FID将一组生成的样本嵌入由特定初始网络层给出的特征空间中。该嵌入层被视为连续的多元高斯,然后估计生成的数据和实际数据的均值和协方差。然后使用这两个高斯之间的Fréchet距离(也被称为 Wasserstein-2距离)来量化生成的样本的质量。较低的FID对应于更相似的实际和生成的样本。

值得注意的是,FID需要一个合适的样本量才能产生良好的结果(建议的大小=50k样本)。如果样本太少,实际FID会被高估,并且估算值差异很大。

一般来说,研究人员对不同领域的评估指标采取不同的方法。对于文本GAN,Guy Tevet和他的团队提出使用传统的基于概率的语言模型度量来评估。

运行自己的GAN

本次练习,主要通过网站comet.ml来实现。Comet.ml允许数据科学家和开发人员轻松监视、比较和优化其机器学习模型。

7f78ee8e8a3f85adf24ac1dda0348ca9.png利用可视化损耗和准确度曲线以及使用Comet.ml检查测试输出来跟踪我们的GAN进度

该GAN模型将MNIST训练数据和随机噪声作为输入(具体为噪声的随机向量)来生成以下内容:

  • 图像(在这种情况下,手写数字的图像)。 最终,这些生成的图像将类似于MNIST数据集的数据分布

  • discriminator对生成的图像的预测

Generator和Discriminator模型一起形成对抗模型,对于这个例子,如果对抗模型将生成的图像分类为所有输入的实数,则generator将表现良好。

完整代码请复制以下?链接clone:

https://gist.github.com/ceceshao1/935ea6000c8509a28130d4c55b32fcd6

追踪模型进度

可以使用Comet.ml跟踪Generator和Discriminator模型的训练进度。

在绘制discriminator和对抗模型的准确性和损失过程中,要跟踪的最重要指标是:

  • 鉴别者的损失(见右图中的蓝线) - dis_loss

  • 对抗模型的准确性(见左图中的蓝线) - acc_adv

复制?链接来查看此实验的培训进度:

https://www.comet.ml/ceceshao1/mnist-gan/cf310adacd724bf280323e2eef92d1cd/chart

88cf4309e57ea9aca100f7907505d9bc.png

我们还需要确认培训过程实际上是否正在使用GPU,可以在Comet System Metrics选项卡中查看。

3351478c151f48dc303abf866efd69d5.png

请注意到我们的for循环语句包括从测试向量报告图像的代码,部分原因是可以直观地分析generator和discriminator模型在生成逼真的手写数字方面的表现,并正确地将生成的数字分类为“真实”或“假”,分别。

if i % 500 == 0:# Visualize the performance of the generator by producing images from the test vector images = net_generator.predict(vis_noise)# Map back to original range#images = (images + 1 ) * 0.5 plt.figure(figsize=(10,10)) for im in range(images.shape[0]): plt.subplot(4, 4, im+1) image = images[im, :, :, :] image = np.reshape(image, [28, 28]) plt.imshow(image, cmap='gray') plt.axis('off') plt.tight_layout()# plt.savefig('/home/ubuntu/cecelia/deeplearning-resources/output/mnist-normal/{}.png'.format(i)) plt.savefig(r'output/mnist-normal/{}.png'.format(i)) experiment.log_image(r'output/mnist-normal/{}.png'.format(i)) plt.close('all')

复制以下?链接来看看这些生成的结果!

https://www.comet.ml/ceceshao1/mnist-gan

可以看到Generator模型如何从这个模糊的灰色输出(参见下面的0.png)开始,它看起来并不像我们期望的手写数字。

40d02c959b1d4688970d30c1cf17854d.png

随着培训的进行和我们模型的损失下降,生成的数字变得更加清晰。 查看生成的输出:

500步后:

ff1c034b9e5ec0ae0373e8547cedea21.png

1000步后:

0b48e97f2020fb59cddb5f2cfa8638ce.png

1500步后:

df62aadf2985dc301f756ba9e72ce78d.png

10000步以后,注意红色框中的数字

e54c4c35a732ebb0e1d21b8ea554ee13.png

完成训练后海可以在Comet的图形选项卡中查看我们报告的输出作为视频(只需按下播放按钮!)。

be2482eb11bc0b5707e250c88df3130b.gif

为了完成最后的实验步骤,确保运行了以下?命令来查看有关模型和GPU使用情况的一些摘要统计信息:

experiment.end()

41e9274944f824be55a94c149c02fa20.png

模型迭代

模型可以更长时间地训练以查看它如何影响性能,感兴趣的话还可以尝试使用几个不同的参数进行迭代,比如:discriminator的优化器、学习率、dropout概率、batch大小等等。

可以通过Comet重新生成一个测试网址,这样就可以对差异进行比较:

7639029876a9e7dff194d66b1bfba92b.png查看两个实验的超参数之间的差异

参考链接:

https://towardsdatascience.com/graduating-in-gans-going-from-understanding-generative-adversarial-networks-to-running-your-own-39804c283399


【2019新智元 AI 技术峰会精彩回顾

2019年3月27日,新智元再汇AI之力,在北京泰富酒店举办AI开年盛典——2019新智元AI技术峰会。峰会以“智能云•芯世界“为主题,聚焦智能云和AI芯片的发展,重塑未来AI世界格局。

同时,新智元在峰会现场权威发布若干AI白皮书,聚焦产业链的创新活跃,评述AI独角兽影响力,助力中国在世界级的AI竞争中实现超越。

嘉宾精彩演讲:

a02f01c646ce98700eeba1904a24578a.png

103db371d41d7c3d1394148d08cab216.png

a22d9feb5e218a464d491319fcb1b40c.png

bd0015edb80dfd086a7c767603f9dc1e.png

6310dcb7041cd8f0dc16b59da6d1f007.png

fac7d90d08a71af4334c753295c0da37.png

f55d0f9c11e26366b3803cb712571a34.png

8c66f6598238d30c6cb665e8c73c35ba.png

addfcd0602cf04655f676936d11924b1.png

b0f5d7237c2a0b695b5eee4dcbf7c97f.gif

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值