对抗生成网络(GAN)

本文介绍了对抗生成网络(GAN)的概念,它通过训练一个生成模型G和一个区分模型D,使得G能生成与真实数据分布相似的样本。GAN在训练过程中,两者相互博弈,逐渐提升生成样本的质量。文章讨论了GAN的优缺点,以及在训练过程中的收敛性和全局最优解,并提供了相关实验结果和资源链接。
摘要由CSDN通过智能技术生成

对抗生成网络(GAN)

摘要:
我们提出了一个通过对抗过程来估计生成模型的新框架,在该框架中,我们同时训练了两个模型:捕获数据分布的生成模型G和估计样本来自训练数据的概率的区分模型D, G的训练过程是使D犯错的可能性最大化。 该框架对应于minimax两人游戏。 在任意函数G和D的空间中,存在唯一的解决方案,其中G恢复训练数据分布,并且D各处都等于1/2。 在G和D由多层感知器定义的情况下,整个系统可以通过反向传播进行训练。 在训练或样本生成期间,不需要任何马尔可夫链或展开的近似推理网络。 实验通过对生成的样本进行定性和定量评估,证明了该框架的潜力。

Introdaction:
到目前为止,深度学习中最显着的成功涉及判别模型,通常是那些将高维,丰富的感官输入映射到分类标签的模型[14,22]。深度生成模型的影响较小,这是由于难以估计在最大似然估计和相关策略中出现的许多棘手的概率计算,并且由于难以在生成上下文中利用分段线性单位的优势。

在提出的对抗网络框架中,生成模型与一个对手相对立:一个判别模型,该模型学习确定样本是来自模型分布还是来自数据分布。 生成模型可以被认为类似于一组伪造者,试图生产假币并在未经检测的情况下使用它,而区分模型类似于警察,试图发现伪币。 在这场比赛中,比赛迫使两支球队都改进自己的方法,直到假冒伪劣品与真品无法区分为止。

Related work
具有潜在变量的有向图形模型的替代方法是具有潜在变量的无向图形模型,例如受限的Boltzmann机器(RBM)[27、16],深Boltzmann机器(DBM)[26]及其众多变体。 此类模型中的交互表示为未归一化的潜在函数的乘积,该函数通过对随机变量所有状态的全局求和/积分来归一化。尽管可以通过马尔可夫链蒙特卡洛(MCMC)方法进行估算,但对于最琐碎的实例而言,此数量(分区函数)及其梯度对于所有实例而言都是棘手的。 对于依赖MCMC的学习算法,混合提出了一个重要的问题[3,5]。

深度信念网络(DBN)[16]是包含单个无向层和多个有向层的混合模型。 虽然存在快速的近似逐层训练准则,但DBN会引起与无向和有向模型相关的计算困难。

还提出了不近似或限制对数可能性的替代标准,例如得分匹配[18]和噪声对比估计(NCE)[13]。这两种方法都需要对所获知的概率密度进行解析指定,直至归一化常数。请注意,在许多有趣的具有几层潜在变量的生成模型(例如DBN和DBM)中,甚至不可能得出可控的未归一化概率密度。一些模型,例如去噪自动编码器[30]和压缩自动编码器,都具有与应用于RBM的得分匹配非常相似的学习规则。在NCE中,就像在这项工作中一样,采用判别性训练标准来适应生成模型。但是,生成模型本身不是用来拟合单独的判别模型,而是用于从样本中以固定的噪声分布来区分生成的数据。由于NCE使用固定的噪声分布,因此在模型学习到一小部分观察变量后,即使学习到近似正确的分布,学习也会大大减慢。

最后,某些技术不涉及显式定义概率分布,而是训练生成机从所需分布中提取样本。 这种方法的优势在于,可以将此类机器设计为通过反向传播进行训练。该领域最近的杰出工作包括生成随机网络(GSN)框架[5],该框架扩展了广义降噪自动编码器[4]:两者都可以看作是定义了参数化的马尔可夫链,即,人们可以学习一台机器的参数。执行生成马尔可夫链的一步。与GSN相比,对抗网络框架不需要马尔可夫链进行采样。由于对抗网络在生成过程中不需要反馈回路,因此它们能够更好地利用分段线性单元[19、9、10],这可以提高反向传播的性能,但在反馈回路中使用时,存在无限激活的问题。 通过反向传播对生成机器进行训练的最新例子包括最近对变分贝叶斯[20]和随机反向传播[24]进行自动编码的工作。

Adversarial nets:
生成器在数据x上的分布pg,先验的输入噪声变量pz(z),G(z;θg)表示到数据空间的映射,其中G关于参数θg可微 。 判别器D(x;θd),它输出一个标量。 D(x)表示x来自真实数据而非pg的概率。 我们训练D来最大化为G训练样本和样本分配正确标签的可能性。我们同时训练G来最小化log(1- D(G(z))):
在这里插入图片描述
在优化D的k个步骤和优化G的一个步骤之间交替进行。只要G的变化足够缓慢,就可以使D保持在其最佳解附近。该过程在算法1中正式提出。
在这里插入图片描述
图1:通过同时更新判别分布(D,蓝色,虚线)来训练生成对抗网络,以便区分生成数据的分布(黑色,虚线)px的样本与生成分布pg(G)的样本之间的区别(绿色实线)。下部水平线是从中采样z的域,在这种情况下是均匀采样的。上面的水平线是x的域的一部分。向上的箭头表示映射x = G(z)如何将非均匀分布pg施加到转换后的样本上。 G在高密度区域收缩,在pg低密度区域膨胀。 (a)考虑一个接近收敛的对抗对:pg类似于pdata,D是部分准确的分类器。 (b)在算法的内部循环中,训练D来区分数据中的样本,收敛到D*(x)= pdata(x)/(pdata(x)+ pg(x))。 (c)在更新G之后,D的坡度已引导G(z)流向更可能被归类为数据的区域。 (d)经过几个步骤的训练,如果G和D具有足够的能力,则它们将达到不能提高的点,因为pg = pdata。鉴别符无法区分两个分布,即D(x)= 1/2。

Theoretical Results:
生成器G隐含地定义概率分布pg作为当z〜pz时获得的样本G(z)的分布。 因此,如果有足够的容量和训练时间,我们希望算法1收敛到pdata的一个好的估计量。在这里插入图片描述

  • Global Optimality of pg = pdata
  • Proposition 1.对于固定的G,最佳判别器D为在这里插入图片描述
  • Theorem 1.当且仅当pg = pdata时,才能达到虚拟训练准则C(G)的全局最小值。 此时,C(G)达到值− log 4。
  • Convergence of Algorithm 1
  • Proposition 2.如果G和D具有足够的容量,并且在算法1的每个步骤中,允许鉴别器达到其最佳给定G,并更新pg以改进准则在这里插入图片描述
    然后pg收敛到pdata

Advantages and disadvantages
缺点主要是没有明确表示pg(x),并且在训练过程中D必须与G很好地同步(特别是,在不更新D的情况下G不能训练太多,以避免出现“ Helvetica场景” (其中G将太多的z值折叠为相同的x值,以至于没有足够的多样性来对pdata进行建模),就像必须在学习步骤之间保持Boltzmann机器的负链更新一样。

优点是不再需要马尔可夫链,仅使用backprop即可获得梯度,在学习过程中无需进行推理,并且可以将多种功能集成到模型中。

对抗模型还可以从生成器网络中获得一些统计上的优势,该生成器网络不直接使用数据示例进行更新,而仅使用流经鉴别器的梯度进行更新。 这意味着输入的组成部分不会直接复制到生成器的参数中。 对抗网络的另一个优点是它们可以表示非常尖锐的分布,甚至可以是简并的分布,而基于马尔可夫链的方法要求分布有些模糊,以便链能够在模式之间进行混合。

import torch
import argparse
import os
import numpy as np
import math
import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch

os.makedirs('images', exist_ok=True)

parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
parser.add_argument('--batch_size', type<
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值