This method is mainly used on a Variable: it detaches a tensor from the computation graph. The detached tensor holds the same values as the original Variable but does not require gradient computation.
Its source code is as follows:
def detach(self):
    result = NoGrad()(self)  # this is needed, because it merges version counters
    result._grad_fn = None
    return result
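To make the behavior concrete, here is a minimal sketch (the names x, y, z are illustrative, not taken from the source above): the detached tensor shares storage with the original and requires no gradient, and because the version counters are merged, an in-place edit through the detached tensor is noticed by autograd.

import torch

x = torch.ones(3, requires_grad=True)
y = x.exp()                          # exp() saves its output for the backward pass
z = y.detach()                       # same data as y, but cut out of the autograd graph

print(z.requires_grad)               # False
print(z.data_ptr() == y.data_ptr())  # True: the storage is shared

z[0] = 10.0                          # in-place edit through the detached tensor
# Because detach() merges version counters, autograd notices the edit:
# calling y.sum().backward() now raises a RuntimeError (a tensor saved
# for backward was modified in place).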
detach is useful when a Variable produced by network A needs to take part in network B's backprop, but A's gradients should not be updated.
For example, in pix2pix, when the fake result produced by the Generator is fed into the Discriminator for the Discriminator's training step, we do not want to update the Generator, so the fake image is detached:
def forward(self):
    """Run forward pass; called by both functions <optimize_parameters> and <test>."""
    self.fake_B = self.netG(self.real_A)  # G(A)

def backward_D(self):
    """Calculate GAN loss for the discriminator"""
    # Fake; stop backprop to the generator by detaching fake_B
    fake_AB = torch.cat((self.real_A, self.fake_B), 1)  # we use conditional GANs; we need to feed both input and output to the discriminator
    pred_fake = self.netD(fake_AB.detach())
    self.loss_D_fake = self.criterionGAN(pred_fake, False)
    # Real
    real_AB = torch.cat((self.real_A, self.real_B), 1)
    pred_real = self.netD(real_AB)
    self.loss_D_real = self.criterionGAN(pred_real, True)
    # combine loss and calculate gradients
    self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
    self.loss_D.backward()
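The same pattern can be checked in isolation. The following self-contained sketch uses two plain Linear layers as stand-ins for the generator and discriminator (netG, netD, and the random data are placeholders, not the pix2pix code): after backward(), only the discriminator has gradients.

import torch
import torch.nn as nn

netG = nn.Linear(8, 8)                 # stand-in "generator"
netD = nn.Linear(8, 1)                 # stand-in "discriminator"
opt_D = torch.optim.SGD(netD.parameters(), lr=0.1)

real_A = torch.randn(4, 8)
fake_B = netG(real_A)                  # forward pass through G

opt_D.zero_grad()
loss_D = netD(fake_B.detach()).mean()  # detach: no graph edge back into G
loss_D.backward()
opt_D.step()

print(netG.weight.grad is None)        # True: no gradients flowed into G
print(netD.weight.grad is not None)    # True: D was updated as usual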
References:
https://www.cnblogs.com/jiangkejie/p/9981707.html
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/pix2pix_model.py