Pytorch中的detach用法

该方法主要应用在Variable变量上,作用是从分离出一个tensor,值和原Variable一样,但是不需要计算梯度。

其源码如下:

def detach(self):
    result = NoGrad()(self)  # this is needed, because it merges version counters
    result._grad_fn = None
    return result

在需要使用A网络的Variable进行B网络中的backprop操作,但又不想更新A网络梯度时,可以使用detach操作

例如pix2pix中使用Generator产生的fake结果需要传入Discriminator进行训练时,不希望更新Generator,可以使用该方法:

def forward(self):
    """Run forward pass; called by both functions <optimize_parameters> and <test>."""
    self.fake_B = self.netG(self.real_A)  # G(A)

def backward_D(self):
    """Calculate GAN loss for the discriminator"""
    # Fake; stop backprop to the generator by detaching fake_B
    fake_AB = torch.cat((self.real_A, self.fake_B), 1)  # we use conditional GANs; we need to feed both input and output to the discriminator
    pred_fake = self.netD(fake_AB.detach())
    self.loss_D_fake = self.criterionGAN(pred_fake, False)
    # Real
    real_AB = torch.cat((self.real_A, self.real_B), 1)
    pred_real = self.netD(real_AB)
    self.loss_D_real = self.criterionGAN(pred_real, True)
    # combine loss and calculate gradients
    self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
    self.loss_D.backward()

参考https://www.cnblogs.com/jiangkejie/p/9981707.html

以及https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/pix2pix_model.py

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值