pytorch 生成全零tensor 指定type_GAN的快速理解以及Pytorch实现

c1dae20de43760ab4421d28747fa76a7.png

原论文地址:https://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf

GitHub:https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/gan/gan.py

一、GAN 有什么用?

GAN 即 Generative Adversarial Nets,生成对抗网络,从名字上我们可以得到两个信息:

  1. 首先,它是一个生成模型
  2. 其次,它的训练是通过“对抗”完成的

何为生成模型?即,给个服从某种分布(比如正态分布)随机数,模型就可以给你生成一张人脸、一段文字 etc。它要做的,是找到一种映射关系,将随机数的分布映射成数据的分布。

何为对抗?GAN 除了包含一个生成模型 G 外还包含一个 判别模型 D ,G 输入随机数生成数据,D 输入数据输出置信度,1 表示是真实数据,0 表示为 G 伪造的数据;二者通过反复地对抗,最终理想情况下, G 生成的数据与真实数据非常接近,分布也相同,而 D 无论输出真实数据还是 G 伪造的数据都输出0.5。

二、GAN 的目标函数及流程

c774f149956c235f41ce5662e2fa0991.png
  • max 部分的含义是,D 要尽可能正确地识别出真实数据和 G 伪造的数据。
  • min 部分的含义是,G 要尽可能缩小自己生成的数据与真实数据的差别,让 D 真假难别。

整个训练流程如图:

9742cc9c9d16f359ad960c64fb56872b.png

在每一步的训练中:

  1. 取 m 个真实数据,使用 G 和 m 组随机数(一般使用服从正态分布的随机数)生成 m 个假数据
  2. 根据 max 部分的目标更新 D 的参数,提高 D 的分辨能力
  3. 根据 min 部分的目标更新 G 的参数,使 G 生成的数据更有迷惑性

三、GAN 的 Pytorch 实现(使用 mnist 数据集)

import 

在这个实现中需要注意的一点是,原论文中 G 的训练是希望减小 log(1-D(G(z)),而代码中是使用二值交叉熵BCE(G(z), 1),即希望提高-log(D(G(x))),虽然都是希望让 D(G(x)) 趋近于1 ,但数值上还是有细微的不同。

PS:

广告时间啦~

理工狗不想被人文素养拖后腿?不妨关注微信公众号:

b792d92b56ee4bdf24545760294e307a.png
欢迎扫码关注~
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值