【第十一周】李宏毅机器学习笔记09:生成式对抗网络1

摘要

本周学习了生成式对抗网络,了解了生成器和判别器的基础原理,同时学习了 GAN 的训练过程和算法思想。最后了解了 Wasserstein Distance,学习了 WGAN 的一点基础。

Abstract

This week, I learned about Generative Adversarial Networks and understood the basic principles of the generator and discriminator. I also studied the training process and the conceptual ideas behind GAN algorithms. Finally, I was introduced to the concept of Wasserstein Distance and gained some foundational knowledge about WGAN.

1.Generative Adversarial Network

1.1.Generator

在这里插入图片描述

之前我们学习到呢 network 是这样子的:输入一个 x x x 就会输出一个固定的 y y y。我们在这个原则上添加一个简单的分布,比如高斯分布。每次输入一个 x x x 的时候会从分布中采样一个值与 x x x 一起丢进 network 中,使得每次生成的 y y y 值都不固定,并且输出出来的值都是一个复杂的分布,我们把这种 network 称为 generator。

生成器(Generator):生成器的目标是生成尽可能接近真实数据的假数据。它通常是一个神经网络,通过学习真实数据的分布特征,生成新的数据样本。生成器的输入是随机噪声(通常来自高斯分布),输出是生成的数据样本。

那究竟为什么我们要使得 y y y 的输入不固定呢?

在这里插入图片描述

考虑这样一个例子:如上图这是一个小精灵游戏的视频,我们可以把这个视频的前几帧丢到一个 network 里面,然后预测出来它的下一帧。

在某个训练资料里面,小精灵到了一个特定的转角之后可能学习到的是向左转,但是在另一个训练资料里面,小精灵学习到的可能是向右转。又因为一个训练集里面可能同时包含着两个训练资料,这就导致了机器会出现“两面讨好”的现象,小精灵到了某个转角之后下一帧会同时向左向右分裂。

在这里插入图片描述
因此,当我们对网络加入一个分布的采样之后,对于一个固定的训练资料就能输出不一样的结果。

也就是说,一个固定的输入加上分布 z z z,可以获得更加多样性的输出。

在这里插入图片描述
特别是对于那些需要 “创造力” 的任务来说, 我们就非常需要 generative model 。其中,最著名的一个 generative model 就是 Generative Adversarial Network

Generative Adversarial Network(GAN,生成对抗网络)是一种深度学习模型,它由两个主要部分组成:生成器(Generator)和判别器(Discriminator)。GANs的设计目的是为了让生成器学会模仿真实数据的分布,从而能够生成与训练数据相似的新样本。这两个组件通过“对抗”的方式共同进化,最终达到一种平衡状态,在这种状态下生成的数据难以被区分真假。

在这里插入图片描述
以上图为例,假设先不考虑输入 x x x ,输入到 generator 中的应该是从一个正态分布中采样的一个低维向量,通过 generator 之后将会产生一个代表不同图片的高维向量。

1.2.Discriminator

在这里插入图片描述
在 GAN 中还需要一个组件 discriminator (判别器),这个判别器也是一个神经网络,我们可以将生成器产生的图片输入判别器中生成一个数值(scalar),这个数值越大说明越接近真实值,越小说明越远离真实值。

1.3.Basic Idea of GAN

在这里插入图片描述
在生成二次元图像的例子中,第一代的生成器生成图像之后会交给第一代的判别器进行判别,判别器对比真实图像和生成图像的差距对生成图像进行鉴别。在训练过程中,生成器和判别器同时更新它们的参数并进行迭代进化。

生成器试图最大化判别器错误分类其生成样本的概率,即让判别器将其生成的数据识别为真实数据。相反,判别器试图最小化错误分类的概率,即准确地区分真实数据和生成数据。这种训练过程可以看作是一种“零和游戏”,其中生成器和判别器互相对抗,直到生成器能够生成几乎不可辨别的虚假数据为止。

1.4.Algorithm

1.4.1.Step 1:Fix generator G, and update discriminator D

在这里插入图片描述
第一步,我们应该先固定生成器 G 的参数,然后更新判别器 D 的参数。

例如上图,我们将随机采样的向量丢到生成器里产生四张图片,然后再从数据集里采样四张图片,最后训练判别器来区分真实图片和生成出来的图片。这个任务既可以看做是分类问题也可以看做是回归问题,即可以用两种方法来玩法。

1.4.2.Step 2:Fix discriminator D, and update generator G

在这里插入图片描述

第二步,我们应该固定判别器 D 的参数,然后更新生成器 G 的参数。

我们在这个阶段的任务是让生成器去“骗过”判别器。我们将一个采样出来的向量通过这个 GAN 网络之后会得到一个分数,此时我们通过调整生成器的参数(不调整判别器的参数)来使得我们最后得出来的分数越大越好。这个阶段的训练和之前学习的没什么不同,可以采用梯度下降法来求解。

在这里插入图片描述
最后就是反复进行这个过程:固定住 G 的参数,更新 D 的参数,然后再固定住 D 的参数,更新 G 的参数,一直迭代运行下去。

2.Theory behind GAN

在这里插入图片描述
我们将一个简单的分布 z z z 传入生成器后会产生一个新的分布 P G P_G PG 。假设真实的数据也是一个分布,令其为 P d a t a P_{data} Pdata。我们的任务就是让 P G P_G PG P d a t a P_{data} Pdata 越接近越好。此时可以定义一个新的函数 G ∗ = a r g   m i n   D i v ( P G , P d a t a ) G^*=arg\ min\ Div(P_G,P_{data}) G=arg min Div(PG,Pdata) ,其中 D i v ( x , y ) Div(x,y) Div(x,y) 为求 x x x y y y 之间的发散度。

在生成对抗网络(GANs)中,“发散度”(Divergence)通常指的是衡量两个概率分布之间的差异或距离。GANs的核心思想是通过生成器(Generator)和判别器(Discriminator)之间的对抗过程,使得生成器能够学习到真实数据分布,并生成与真实数据难以区分的样本。在这个过程中,发散度被用来度量生成数据分布与真实数据分布之间的差距。

对于这个新的函数 G ∗ = a r g   m i n   D i v ( P G , P d a t a ) G^*=arg\ min\ Div(P_G,P_{data}) G=arg min Div(PG,Pdata) 的求解似乎和我们之前学过的 w ∗ , b ∗ = a r g   m i n   L w^*,b^* =arg\ min\ L w,b=arg min L 类似,看起来好像可以应用梯度下降法求未知的参数。

但是,使用梯度下降法直接求解是一个理论上看似简单但实际上面临诸多挑战的问题:

1. 发散度的不可微性
许多常用的发散度(如KL散度、JS散度)在其定义域内并不是处处可微的。这意味着它们在某些点上可能没有定义良好的梯度,从而无法直接应用于梯度下降法。例如,KL散度在分布 P G P_G PG P d a t a P_{data} Pdata 支持不重叠的区域是无穷大的,这会导致梯度不存在或者不连续。

2. 高维空间的挑战
在高维空间中,直接计算和优化发散度非常困难。例如,在图像生成任务中,每个像素可以视为一个维度,使得问题变得极其复杂。直接优化高维空间中的发散度通常需要大量的计算资源,并且容易陷入局部最优。

3. 发散度的估计难度
在实际应用中,我们通常只有有限数量的样本,而不是完整的概率分布。直接估计发散度需要从有限样本中推断整个分布,这本身就是一个挑战。此外,即使能够估计发散度,其估计值也可能存在较大的方差,从而影响优化过程。

在这里插入图片描述

由于我们不清楚 P G P_G PG P d a t a P_{data} Pdata 的分布,我们没有办法算出它们的散度。但是 GAN 的理论又告诉我们,只要我们能从 P G P_G PG P d a t a P_{data} Pdata 中采样出来数据,我们就可以计算出它们的散度。

那么我们是如何计算出它们的散度的呢?

计算散度要依靠 Discriminator 的力量。

在这里插入图片描述
我们想要训练一个判别器,该判别器看到真实的数据就会给一个高分,看到生成器的数据就会给一个低分。

在这里插入图片描述
因此,我们只需要训练一组判别器的参数,使得一个函数 V ( D , G ) V(D,G) V(D,G) 取得最大值即可。这个函数如上图所示,我们可以观察到这个函数与一个二分类交叉熵的公式很像,实际上也可以看做训练一个使得交叉熵最低的二分类器。

二分类交叉熵(Binary Cross-Entropy)公式为: H ( y , y ^ ) = − [ y log ⁡ ( y ^ ) + ( 1 − y ) log ⁡ ( 1 − y ^ ) ] H(y, \hat{y}) = - \left[ y \log(\hat{y}) + (1 - y) \log(1 - \hat{y}) \right] H(y,y^)=[ylog(y^)+(1y)log(1y^)]

为什么交叉熵能用于分类

1.正确分类时损失最小:当模型正确分类时,即 y = 1 y=1 y=1 y ^ \hat{y} y^
接近1,或 y = 0 y=0 y=0 y ^ \hat{y} y^ 接近0,交叉熵损失会很小,因为此时 log ⁡ ( y ^ ) \log(\hat{y}) log(y^) log ⁡ ( 1 − y ^ ) \log(1 - \hat{y}) log(1y^) 接近0。
2.错误分类时损失最大:当模型错误分类时,即 y = 1 y=1 y=1 y ^ \hat{y} y^​接近0,或 y = 0 y=0 y=0 y ^ \hat{y} y^ 接近1,交叉熵损失会很大,因为此时 log ⁡ ( y ^ ) \log(\hat{y}) log(y^) log ⁡ ( 1 − y ^ ) \log(1 - \hat{y}) log(1y^) 会使一个较大的附属,导致损失增大。

在这里插入图片描述
P G P_G PG P d a t a P_{data} Pdata 的散度比较小的时候,判别器很难区分两种不同的数据,因此红框内的数值就小,而当 P G P_G PG P d a t a P_{data} Pdata 的散度比较大的时候,判别器就可以很轻易区分出两种不同的数据,因此此时红框内的数据就大。

在这里插入图片描述
由于 P G P_G PG P d a t a P_{data} Pdata 的散度与 max ⁡ D   V ( D , G ) \max\limits_{D}\ V(D,G) Dmax V(D,G) 相关,实际上公式可以改下为如下:

G ∗ = a r g max ⁡ G max ⁡ D V ( D , G ) G^*=arg\max\limits_{G} \max\limits_{D}V(D,G) G=argGmaxDmaxV(D,G)

3.Tips for GAN

在这里插入图片描述
JS散度并不适合 GAN 。

在大多数情况下, P G P_G PG P d a t a P_{data} Pdata 都是不重合的:

  1. 数据的性质: P G P_G PG P d a t a P_{data} Pdata 是高维空间中的低维向量,重叠部分可以忽略不记。
  2. 采样的特点: 即使 P G P_G PG P d a t a P_{data} Pdata 是重合的,由于我们的数据是采样出来的,只有当我i们采样足够多的时候才会出现数据重叠现象。

在这里插入图片描述
对于两组并不重合的数据来说,JS 散度无论怎么求都总会是 log ⁡ 2 \log2 log2,这就导致了区分不出来不同数据的好坏。

3.1.Wasserstein Distance

在这里插入图片描述

因此,我们需要找另一个量来衡量两个分布之间的差异。

Wasserstein Distance(也称为地球搬运者距离(Earth-Mover’s Distance,EMD))是一种衡量两个概率分布之间差异的方法。它在概率论、统计学和机器学习中被广泛应用,尤其是在生成对抗网络(GANs)中,用于度量两个概率分布之间的距离。Wasserstein距离可以形式化地定义为将一个概率分布变形为另一个概率分布所需要的最小“工作量”。

在这里插入图片描述
把一个分布“推”成另一个分布有很多种方法,我们需要穷举所有的 moving plan,再把平均距离最小的那个方法定义为 Wasserstein Distance。

引入 Wasserstein Distance 的好处显而易见。

在这里插入图片描述
当使用 JS 散度时,我们观察不到训练朝着越来越好的方向前进;而当我们使用 Wasserstein Distance 时,可以在训练过程中看到 W ( P G , P d a t a ) W(P_G , P_{data}) W(PG,Pdata) 逐渐降低。

3.2.WGAN

在这里插入图片描述
当我们使用 Wasserstein Distance 去取代 JS 散度后,GAN 就变成了 WGAN。

在这里插入图片描述

总结

生成对抗网络(GAN)由生成器和判别器两部分组成:生成器学习生成类似真实数据的样本,而判别器则负责区分真实数据与生成的数据。传统的GAN采用最小最大博弈框架,其中生成器试图“欺骗”判别器,使它难以分辨真假,而判别器则努力准确地区分真实数据和生成数据。相比之下,Wasserstein GAN(WGAN)通过使用Wasserstein距离作为损失函数来改进传统GAN,这种方法提供了更好的梯度信息并增强了训练的稳定性。WGAN通过确保判别器(称为批评器)是1-Lipschitz连续的,通常通过权重裁剪或梯度惩罚技术来实现这一点,从而获得更稳定的收敛性和更高的生成质量。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值