深度学习TF—14.WGAN原理及实战

本文详细介绍了WGAN( Wasserstein GAN)的原理,包括JS散度的缺陷、Wasserstein距离的引入及其优势,以及WGAN的损失函数。WGAN解决了传统GAN训练不稳定的问题,通过使用Wasserstein距离,能够提供有效的梯度信息,提高了训练的稳定性。此外,文章还简要概述了WGAN的实战步骤,包括数据集加载、网络构建和完整代码。
摘要由CSDN通过智能技术生成

一、WGAN原理

  WGAN 算法从理论层面分析了GAN 训练不稳定的原因,并提出了有效的解决方法。那么是什么原因导致了GAN 训练如此不稳定呢?WGAN 提出是因为JS 散度在不重叠的分布𝑝和𝑞上的梯度曲面是恒定为0 的。当分布𝑝和𝑞不重叠时,JS 散度的梯度值始终为0,从而导致此时GAN 的训练出现梯度弥散现象,参数长时间得不到更新,网络无法收敛。

1.JS散度的缺陷

  下面通过一个简单的分布实例来解释JS 散度的缺陷。考虑完全不重叠(𝜃 ≠ 0)的两个分布𝑝和𝑞

分布𝑝为:∀(𝑥, 𝑦) ∈ p, 𝑥 = 0, 𝑦 ∼ U(0,1)
分布𝑞为:∀(𝑥, 𝑦) ∈ 𝑞, 𝑥 = 𝜃, 𝑦 ∼ U(0,1)
其中𝜃 ∈ 𝑅,当𝜃 = 0时,分布𝑝和𝑞重叠,两者相等;当𝜃 ≠ 0时,分布𝑝和𝑞不重叠。

在这里插入图片描述
我们分析上述分布𝑝和𝑞之间的JS 散度随𝜃的变化情况。
根据KL 散度与JS 散度的定义,计算𝜃 = 0时的JS 散度𝐷𝐽𝑆(𝑝||𝑞):
在这里插入图片描述
当𝜃 = 0时,两个分布完全重叠,此时的JS 散度和KL 散度都取得最小值0
在这里插入图片描述
𝐷𝐽𝑆(𝑝||𝑞)随𝜃的变化趋势为:
在这里插入图片描述
也就是说:当两个分布完全不重叠时,无论分布之间的距离远近,JS 散度为恒定值log2,此时JS 散度将无法产生有效的梯度信息;当两个分布出现重叠时,JS 散度才会平滑变动,产生有效梯度信息;当完全重合后,JS 散度取得最小值0。
在这里插入图片描述

由图可知:由于p分布与q分布不重叠,随着q的移动,生成样本位置处的梯度值始终为0,无法更新生成网络的参数,从而出现网络训练困难的现象。因此,JS 散度在分布𝑝和𝑞不重叠时是无法平滑地衡量分布之间的距离,从而导致此位置上无法产生有效梯度信息,出现GAN 训练不稳定的情况。要解决此问题,需要使用一种更好的分布距离衡量标准,使得它即使在分布𝑝和𝑞不重叠时,也能平滑反映分布之间的真实距离变化。

2.Wasserstein 距离

  WGAN 论文发现了JS 散度导致GAN 训练不稳定的问题,并引入了一种新的分布距离度量方法:Wasserstein 距离,它表示了从一个分布变换到另一个分布的最小代价,定义为:
在这里插入图片描述
其中Π(𝑝, 𝑞)是分布𝑝和𝑞组合起来的所有可能的联合分布的集合,对于每个可能的联合分布𝛾 ∼ Π(𝑝, 𝑞),计算距离‖𝑥 − 𝑦‖的期望𝔼(𝑥,𝑦)∼𝛾[‖𝑥 − 𝑦‖],其中(𝑥, 𝑦)采样自联合分布𝛾。不同的联合分布𝛾有不同的期望𝔼(𝑥,𝑦)∼𝛾[‖𝑥 − 𝑦‖],这些期望中的下确界即定义为分布𝑝和𝑞的Wasserstein 距离。
  绘制出 JS 散度和EM 距离的曲线,如图所示,可以看到,JS 散度在𝜃 = 0处不连续,其他位置导数均为0,而EM 距离总能够产生有效的导数信息,因此EM 距离相对于JS 散度更适合指导GAN 网络的训练。
在这里插入图片描述

3.损失函数

在这里插入图片描述

前面是EM距离,后面是GP惩罚项

其中𝒙̂来自于𝒙𝑟与𝒙𝑓的线性差值:

𝑥̂ = 𝑡𝒙𝑟 + (1 − 𝑡)𝒙𝑓 , 𝑡 ∈ [0,1]
  判别器 D 的目标是最小化上述的误差ℒ(𝐺, 𝐷),即迫使生成器G 的分布𝑝𝑔与真实分布𝑝𝑟之间EM 距离𝔼𝒙𝑟∼𝑝𝑟[𝐷(𝒙𝑟)]−𝔼𝒙𝑓∼𝑝𝑔 [𝐷(𝒙𝑓)]项尽可能小,‖𝛻𝒙̂𝐷(𝒙̂)‖2逼近于1。
  WGAN 的生成器G 的训练目标为:
在这里插入图片描述
即使得生成器的分布𝑝𝑔与真实分布𝑝𝑟之间的EM 距离越小越好。考虑到𝔼𝒙𝑟∼𝑝𝑟[𝐷(𝒙𝑟)]一项与生成器无关,因此生成器的训练目标简写为:
在这里插入图片描述
  从实现来看,判别网络D 的输出不需要添加Sigmoid 激活函数,这是因为原始版本的判别器的功能是作为二分类网络,添加Sigmoid 函数获得类别的概率;而WGAN 中判别器作为EM 距离的度量网络,其目标是衡量生成网络的分布𝑝𝑔和真实分布𝑝𝑟之间的EM 距离,属于实数空间,因此不需要添加Sigmoid 激活函数。在误差函数计算时,WGAN 也没有log 函数存在。在训练WGAN 时,WGAN 作者推荐使用RMSProp 或SGD 等不带动量的优化器。

  WGAN 从理论层面发现了原始GAN 容易出现训练不稳定的原因,并给出了一种新的距离度量标准和工程实现解决方案,取得了较好的效果。WGAN 还在一定程度上缓解了模式崩塌的问题,使用WGAN 的模型不容易出现模式崩塌的现象。需要注意的是,WGAN一般并不能提升模型的生成效果,仅仅是保证了模型训练的稳定性。当然,保证模型能够稳定地训练也是取得良好效果的前提。

二、WGAN实战

1.数据集的加载
# 加载数据集的函数
import multiprocessing

import tensorflow as tf


def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
    @tf.function
    def _map_fn(img):
        img = tf.image.resize(img, [resize, resize])
        img = tf.clip_by_value(img, 0, 255)
        img = img / 127.5 - 1
        return img

    dataset = disk_image_batch_dataset(img_paths,
                                          batch_size,
                                          drop_remainder=drop_remainder,
                                          map_fn=_map_fn,
                                          shuffle=shuffle,
                                          repeat=repeat)
    img_shape = (resize, resize, 3)
    len_dataset = len(img_paths) // batch_size

    return dataset, img_shape, len_dataset


def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=None):
    # set defaults
    if n_map_threads is None:
        n_map_threads = multiprocessing.cpu_count()
    if shuffle and shuffle_buffer_size is None:
        shuffle_buffer_size = max(batch_size * 128, 2048)  # set the minimum buffer size as 2048

    # [*] it is efficient to conduct `shuffle` before `map`/`filter` because `map`/`filter` is sometimes costly
    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer_size)

    if not filter_after_map:
        if filter_fn:
            dataset = dataset.filter(filter_fn)

        if map_fn:
            dataset = dataset.map(map_fn, num_parallel_calls=n_map_threads)

    else:  # [*] this is slower
        if map_fn:
            dataset = dataset.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值