巧断梯度:单个loss实现GAN模型(附开源代码)

640


作者丨苏剑林

单位丨广州火焰信息科技有限公司

研究方向丨NLP,神经网络

个人主页丨kexue.fm


我们知道普通的模型都是搭好架构,然后定义好 loss,直接扔给优化器训练就行了。但是 GAN 不一样,一般来说它涉及有两个不同的 loss,这两个 loss 需要交替优化。


现在主流的方案是判别器和生成器都按照 1:1 的次数交替训练(各训练一次,必要时可以给两者设置不同的学习率,即 TTUR),交替优化就意味我们需要传入两次数据(从内存传到显存)、执行两次前向传播和反向传播。


如果我们能把这两步合并起来,作为一步去优化,那么肯定能节省时间的,这也就是 GAN 的同步训练。


注:本文不是介绍新的 GAN,而是介绍 GAN 的新写法,这只是一道编程题,不是一道算法题。


如果在TF中


如果是在 TensorFlow 中,实现同步训练并不困难,因为我们定义好了判别器和生成器的训练算子了(假设为 D_solver G_solver ),那么直接执行:


sess.run([D_solver, G_solver], feed_dict={x_in: x_train, z_in: z_train})


就行了。这建立在我们能分别获取判别器和生成器的参数、能直接操作 sess.run 的基础上。


更通用的方法 


但是如果是 Keras 呢?Keras 中已经把流程封装好了,一般来说我们没法去操作得如此精细。


  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值