Pytorch: detach 和 retain_graph,和 GAN的原理解析

5 篇文章 1 订阅

转载知乎文章:https://zhuanlan.zhihu.com/p/43843694

转载知乎文章:https://zhuanlan.zhihu.com/p/43843694

写的真的很棒!看的很明白.谢谢作者

本人观察 Pytorch 下的生成对抗网络(GAN)的实现代码,发现不同人的实现细节略有不同,其中用到了 detach 和 retain_graph,本文通过两个 gan 的代码,介绍它们的作用,并分析,不同的更新策略对程序效率的影响。

这两个 GAN 的实现中,其更新策略不同,前者是先更新判别器(discriminator)参数,再更新生成器(generator)参数,这正是原始论文Generative Adversarial Networks 中的算法(下图所示);后者是先更新 generator 参数,再更新 discriminator 参数,它们的实现孰优孰劣呢?

 

GAN 的基本原理

首先回顾一下生成对抗网络算法的基本原理,对于熟悉的同学,可以跳过这部分。限于篇幅,只介绍最原始的两种 GAN 损失函数,不失一般性。

原始 GAN 的损失函数,主要来源于 binary cross-entropy loss:

L(x^{(i)},y^{(i)})=-y^{(i)}\cdot\mathrm{log}P(x^{(i)})-(1-y^{(i)})\cdot\mathrm{log}(1-P(x^{(i)}))\tag{1}

其中,y 为真实标签,取 1 (正样本) 或者 0 (负样本),P(x) 为 x 属于正样本的概率。

判别器有两种输入,一种是真实的样本 (x\in P_r,y=1);一种是生成器通过噪音 z 生成的假样本 (G(z)\in P_g, y=0) 。判别器的工作是将这两种来源的样本区分开,故采用公式 (1) 计算其损失函数 L_d

分别把真实样本和生成的假样本代入公式 (1),对于每个正样本,其损失函数只剩下:

L_{d_{real}}(x^{(i)}\in P_r,y^{(i)}=1)=-\mathrm{log}D(x^{(i)})\tag{2}

而对于每个负样本,损失函数为:

L_{d_{fake}}(G(z^{(i)}),y^{(i)}=0)=-\mathrm{log}\left(1-D(G(z^{(i)}))\right)\tag{3}

(这里把公式 (1) 的 P(x) 改为 D(x),表示是 discriminator 的输出,和上图中的算法保持一致)。把这两个损失函数加起来,算一下平均值:

L_{d_{total}}=-\frac{1}{m}\sum_{i=1}^{m}\mathrm{log}D(x^{(i)})+\mathrm{log}\left(1-D(G(z^{(i)}))\right)\tag{4}

然后,将这个损失函数对 discriminator 的参数 \theta_d 求导,即得到上图算法中的梯度:

\frac{\mathrm{d}L_{d_{total}}}{\mathrm{d}\theta_d}=-\nabla_{\theta_d}\frac{1}{m}\sum_{i=1}^{m}\mathrm{log}D(x^{(i)})+\mathrm{log}\left(1-D(G(z^{(i)}))\right)\tag{5}

注意算法截图中用梯度上升优化,所以和上式差一个负号

生成器输入噪声 z^{(i)} ,输出一个假数据 G(z^{(i)}) 。它希望这个假数据能骗过判别器。而上文判别器对假数据的损失函数为:

L_{d_{fake}}(G(z^{(i)}),y^{(i)}=0)=-\mathrm{log}\left(1-D(G(z^{(i)}))\right)\tag{3}

这个损失越大,意味着判别器的性能越差,同时也意味着生成器的性能越好。也就是说,我们希望朝着增大该损失的方向来调整生成器的参数。如果采用梯度下降优化生成器,也就是要减小它的相反数,即生成器的损失函数:

L_g(z^{(i)})=\mathrm{log}\left(1-D(G(z^{(i)}))\right)\tag{6}

这就是原始 GAN 生成器的第一种损失函数,对其取平均再对生成器参数 \theta_g 求梯度,得到算法截图里面的公式:

\frac{\mathrm{d}L_{g_{total}}}{\mathrm{d}\theta_g}=\nabla_{\theta_g}\frac{1}{m}\sum_{i=1}^m\mathrm{log}\left(1-D(G(z^{(i)}))\right)\tag{7}

算法截图中用梯度下降优化,所以和上式相同

但是,公式 (6) 的损失函数在训练初期梯度会特别小。因为刚开始训练时,生成器生成的数据十分不真实,导致判别器很容易判断出来它是假数据。因此 D(G(z^{(i)})) 是个接近于0 的常数。而损失函数 (6) 对 generator 参数 \theta_g 的导数为:

\frac{\mathrm{d}L_g(z^{(i)})}{\mathrm{d}\theta_g}=-\frac{1}{1-D(G(z^{(i)}))}\cdot\frac{\mathrm{d}D(G(z^{(i)}))}{\mathrm{d}G(z^{(i)})}\cdot\frac{\mathrm{d}G(z^{(i)})}{\mathrm{d}\theta_g}\tag{8}

由于 D(G(z^{(i)})) 是接近于0的常数,所以 (8) 中等号右边第一项近似为 -1,而中间项为接近 0 的数(因为 D(G(z^{(i)})) 是接近 0 的常函数,而常函数的梯度为 0)。

因此,GAN 的作者又给出了第二种生成器损失函数:

L_g(z^{(i)})=-\mathrm{log}(D(G(z^{(i)})))\tag{9}

这个损失函数很好理解,就是要生成这样的数据:使得输入 discriminator 后,输出一个较大的概率,即对生成器来说 \mathrm{log}(D(G(z^{(i)}))) 越大越好,取它的相反数,也就是公式 (9) 越小越好。我们计算公式 (9) 的损失函数对生成器参数 \theta_g 的导数:

\frac{\mathrm{d}L_g(z^{(i)})}{\mathrm{d}\theta_g}=-\frac{1}{D(G(z^{(i)}))}\cdot\frac{\mathrm{d}D(G(z^{(i)}))}{\mathrm{d}G(z^{(i)})}\cdot\frac{\mathrm{d}G(z^{(i)})}{\mathrm{d}\theta_g}\tag{10}

和 (8) 比只有等号右边第一项不同,当 D(G(z^{(i)})) 接近 0 时, 1/D(G(z^{(i)})) 接近无穷大,和中间那项相乘,一定程度抵消了中间项接近 0 的问题,使计算出来的梯度增大了。

在实现过程中,基本上都是采用公式 (9) 作为生成器损失函数,因为这十分方便,只要在计算生成器损失函数时,把虚假数据的标签标记为真,代入 Binary cross-entropy loss 的公式即可。

更新策略

下面进入本文正题,即,在 pytorch 中,detach 和 retain_graph 是干什么用的?本文将借助两段 GAN 的实现代码,来举例介绍它们的作用。

第一段代码先更新判别器,再更新生成器。我们分析循环中一个 step 的代码:

valid = torch.Tensor(imgs.size(0), 1).fill_(1.0).to(device) # 真实标签,都是1
fake = torch.Tensor(imgs.size(0), 1).fill_(0.0).to(device) # 假标签,都是0
#----------
# 训练判别器
#----------
real_imgs = imgs.to(device)
z = torch.randn((imgs.shape[0], 100)).to(device) # 噪声
gen_imgs = generator(z) # 从噪声中生成假数据
pred_gen = discriminator(gen_imgs) # 判别器对假数据的输出
pred_real = discriminator(real_imgs) # 判别器对真数据的输出
optimizer_D.zero_grad() # 把判别器中所有参数的梯度归零
real_loss = adversarial_loss(pred_real, valid) # 判别器对真实样本的损失
fake_loss = adversarial_loss(pred_gen, fake) # 判别器对假样本的损失
d_loss = (real_loss + fake_loss) / 2  # 两项损失相加取平均
# 下面这行代码十分重要,将在正文着重讲解
d_loss.backward(retain_graph=True) # retain_graph 十分重要,否则计算图内存将会被释放
optimizer_D.step() # 判别器参数更新
#---------
#训练生成器
#---------
g_loss = adversarial_loss(pred_gen, valid) # 生成器的损失函数
optimizer_G.zero_grad() # 生成器参数梯度归零
g_loss.backward() # 生成器的损失函数梯度反向传播
optimizer_G.step() # 生成器参数更新

上面的代码中 d_loss.backward(retain_graph=True) 十分关键,它用于反向传播 discriminator loss 的梯度。那么,具体传播到什么地方去呢?

这要看 d_loss 是由哪几部分构成的:real_loss 和 fake_loss,而 fake_loss 又是从 noise 经过 generator 来的。因此,d_loss 反向传播,将一传到底,不但计算了 discriminator 的梯度,同时还计算了 generator 的梯度,虽然这一步optimizer_D.step()只更新 discriminator 的参数。

也正是这个原因,下面在更新 generator 参数时,要先将生成器参数的梯度重新归零,避免受到 discriminator loss 回传过来的梯度影响。注意:它在反向传播时,设置了 retain graph = 0 这个参数。它的作用是保持计算图,因为 pytorch 默认一个计算图只计算一次反向传播,反向传播后,这个计算图的内存就被释放了。而后面的 generator 算梯度时还要用到这个计算图,所以用这个参数控制计算图不被释放。

generator 的 损失在回传时,同样要经过 discriminator 网络才能传递回自身(系统从输入噪声到 Discriminator 输出,从头到尾只有一次前向传播,而有两次反向传播,故在第一次反向传播时,要保持计算图不被释放)。因此,在回传梯度时,同样也计算了一遍 discriminator 的参数梯度,只不过这次 discriminator 的参数不更新,只更新 generator 的参数,即 optimizer_G.step()。同时,我们看到,下一个 step 首先将 discriminator 的梯度重置为 0,就是为了防止 generator loss 反向传播时顺带计算的梯度对其造成影响(还有上一步 discriminator loss 回传时累积的梯度)。

综上,我们看到,为了完成一步参数更新,我们进行了两次反向传播,第一次反向传播为了更新 discriminator 的参数,但多余计算了 generator 的梯度。第二次反向传播为了更新 generator 的参数,但是不得不多计算 discriminator 的梯度。

 

对于先更新生成器参数的情况,我们也分析其循环中一个 step 的代码:

valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False) # 真实样本的标签,都是 1
fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False) # 生成样本的标签,都是 0
real_imgs = Variable(imgs.type(Tensor))
#-----------
# 训练生成器
#-----------
optimizer_G.zero_grad() # 生成器参数梯度归零
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim)))) # 噪声
gen_imgs = generator(z) # 根据噪声生成虚假样本
g_loss = adversarial_loss(discriminator(gen_imgs), valid) # 用真实的标签+假样本,计算生成器损失
g_loss.backward() # 生成器梯度反向传播,反向传播经过了判别器,故此时判别器参数也有梯度
optimizer_G.step() # 生成器参数更新,判别器参数虽然有梯度,但是这一步不能更新判别器
#----------
# 训练判别器
#----------
optimizer_D.zero_grad() # 把生成器损失函数梯度反向传播时,顺带计算的判别器参数梯度清空
real_loss = adversarial_loss(discriminator(real_imgs), valid) # 真样本+真标签:判别器损失
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake) # 假样本+假标签:判别器损失
d_loss = (real_loss + fake_loss) / 2  # 判别器总的损失函数
d_loss.backward() # 判别器损失回传
optimizer_D.step() # 判别器参数更新

上述代码先更新生成器参数,再更新判别器参数。那么除了顺序上的区别,和先更新判别器,再更新生成器,还有什么不同?答案是,计算图的遍历次数不同。

为了更新生成器参数,用生成器的损失函数计算梯度,然后反向传播,传播图中经过了discriminator,根据链式法则,不得不顺带计算一下判别器的参数梯度,虽然在这一步不会更新判别器参数。反向传播过后,noise 到 fake image 再到 discriminator 的输出这个前向传播的计算图就被释放掉了,后面也不会再用到。

接着更新判别器参数,此时注意到,我们输入判别器的是两部分,一部分是真实数据,另一部分是生成器的输出,也就是假数据。注意观察细节,在判别器前向传播过程,输入的假数据被 detach 了:discriminator(gen_imgs.detach()),detach 的意思是,这个数据和生成它的计算图“脱钩”了,即梯度传到它那个地方就停了,不再继续往前传播(实际上也不会再往前传播了,因为 generator 的计算图在第一次反向传播过后就被释放了)。因此,判别器梯度反向传播,就到它自己身上为止。

因此,比起第一种策略,要少计算一次 generator 的所有参数的梯度,同时,也不必刻意保存一次计算图,占用不必要的内存。

但需要注意的是,在第一种策略中,noise 从 generator 输入,到 discriminator 输出,只经历了一次前向传播,discriminator 端的输出,被用了两次,一次是计算 discriminator 的损失函数,另一次是计算 generator 的损失函数。

而在第二种策略中,noise 从 generator 输入,到discriminator 输出,计算 generator 损失,回传,这一步更新了 generator 的参数,并释放了计算图。下一步更新 discriminator 的参数时,generator 的输出经过 detach 后,又通过了一遍 discriminator,相当于,generator 的输出前后两次通过了 discriminator ,得到相同的输出。显然,这也是冗余的。

 

综上,这两段代码各有利弊:

第一段代码,好处是 noise 只进行了一次前向传播,缺点是,更新 discriminator 参数时,多计算了一次 generator 的梯度,同时,第一次更新 discriminator 需要保留计算图,保证算 generator loss 时计算图不被销毁。

第二段代码,好处是通过先更新 generator ,使更新后的前向传播计算图可以放心被销毁,因此不用保留计算图占用内存。同时,在更新 discriminator 的时候,也不会像上面的那段代码,计算冗余的 generator 的梯度。缺点是,在 discriminator 上,对 generator 的输出算了两次前向传播,第二次又产生了新的计算图(但比第一次的小)。

一个多计算了一次 generator 梯度,一个多计算一次 discriminator 前向传播。因此,两者差别不大。如果 discriminator 比generator 复杂,那么应该采取第一种策略,如果 discriminator 比 generator 简单,那么应该采取第二种策略,通常情况下,discriminator 要比 generator 简单,故应该采取第二种策略居多。

 

但是第二种先更新generator,再更新 discriminator 总是给人感觉怪怪得,因为 generator 的更新需要 discriminator 提供准确的 loss 和 gradient,否则岂不是在瞎更新?

 

还有一种没提到的策略,noise 从 generator 输入,输出 fake data,然后 detach 一下,随着 true data 一起输入 discriminator,计算 discriminator 损失,并更新 discriminator 参数。接下来,再把没经过 detach 的 fake data 输入到discriminator 中,计算 generator loss,再反向传播梯度,更新 generator 的参数。这种策略,计算了两次 discriminator 梯度,一次 generator 梯度。感觉这种比较符合先更新 discriminator 的习惯。缺点是,之前的 generator 生成的计算图得保留着,直到 discriminator 更新完,再释放。不像策略二,马上用完马上释放。综合来说,还是策略二最好,策略三其次,策略一最差(差在多计算一次 generator gradient 上,而通常多计算一次 generator gradient 的运算量比多计算一次 discriminator 前向传播的运算量大),因此,detach 还是很有必要的。

  • 7
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值