GAN 中关于辨别器 detach()函数的作用

detach()函数的作用相信大家已经在网上搜过了,这里再简单叙述一下,照搬一下,

即返回一个新的tensor,从当前计算图中分离下来。但是仍指向原变量的存放位置,不同之处只是requirse_grad为false.得到的这个tensir永远不需要计算器梯度,不具有grad。简单来说detach就是截断反向传播的梯度流。

之前测试了一篇mnist-GAN的代码,代码很简单,最近要用到他。但是如下所示,辨别器D中有个detach()函数,

 D_optimizer.zero_grad() 
        fake = G(z)  
        
        d_fake_res = D(fake.detach()) 
        d_fake_loss = BCE_loss(d_fake_res,torch.zeros_like(d_fake_res)) 
        d_real_res = D(x)  # shape : torch.Size([128,1])
        d_real_loss = BCE_loss(d_real_res,torch.ones_like(d_real_res)) 
        d_loss = (d_fake_loss + d_real_loss)/2

        mean_D_loss += d_loss.item() / display_step
        d_loss.backward()
        D_optimizer.step() 

在结合上网查资料之后我所理解的如下:

用到detach()函数的原因是因为辨别器的损失函数loss是由两部分组成的,其中的一部分还与生成器有关。所以当执行d_loss.backward() 反向传播的时候,会把所有相关权重参数的梯度都给计算到,虽然辨别器的优化器参数给定的只是它自己的网络参数本身,但是这并不妨碍反向传播要计算有关反向传播计算图中的所有权重参数。这样会造成资源的浪费。所以在训练D时要用detach来截断反向传播流,使反向传播只计算完有关辨别器的所有权重参数就不再继续往下进行了。(注意,反向传播时辨别器是在第一位置,然后才传到生成器,即辨别器的输出→ 辨别器→生成器)。

然后训练生成器的时候,由于正向过程中生成器在辨别器之前,反向传播需要从辨别器的输出往回传,所以反向传播时的顺序是  辨别器的输出→ 辨别器→生成器,所以训练生成器时不需要再阻断反向传播的梯度流了。

希望能狗帮助到你,如果我的叙述有误,请留言评论探讨,谢谢帮我改正。加油(ง •_•)ง

  • 15
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 5
    评论
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

匿名的魔术师

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值