Pytorch入门学习(九)---detach()的作用(从GAN代码分析)

转自:点击打开链接

总说

简单来说detach就是截断反向传播的梯度流

    def detach(self):
        """Returns a new Variable, detached from the current graph.

        Result will never require gradient. If the input is volatile, the output
        will be volatile too.

        .. note::

          Returned Variable uses the same data tensor, as the original one, and
          in-place modifications on either of them will be seen, and may trigger
          errors in correctness checks.
        """
        result = NoGrad()(self)  # this is needed, because it merges version counters
        result._grad_fn = None
        return result
  • 可以看到Returns a new Variable, detached from the current graph。将某个node变成不需要梯度的Varibale。因此当反向传播经过这个node时,梯度就不会从这个node往前面传播。

从GAN的代码中看detach()

GAN的G的更新,主要是GAN loss。就是G生成的fake图让D来判别,得到的损失,计算梯度进行反传。这个梯度只能影响G,不能影响D!可以看到,由于torch是非自动求导的,每一层的梯度的计算必须用net:backward才能计算gradInput和网络中的参数的梯度。这个GAN的代码截取自Image-to-Image Translation with Conditional Adversarial Networks.

先看Torch版本的代码

local fGx = function(x)
    netD:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)
    netG:apply(function(m) if torch.type(m):find('Convolution') then m.bias:zero() end end)

    gradParametersG:zero()

    -- GAN loss
    local df_dg = torch.zeros(fake_B:size())
    if opt.use_GAN==1 then
       local output = netD.output -- netD:forward{input_A,input_B} was already executed in fDx, so save computation
       local label = torch.FloatTensor(output:size()):fill(real_label) -- fake labels are real for generator cost

       errG = criterion:forward(output, label)
       local df_do = criterion:backward(output, label)
       df_dg = netD:updateGradInput(fake_AB, df_do):narrow(2,fake_AB:size(2)-output_nc+1, output_nc)
    else
        errG = 0
    end

    -- unary loss
    -- 得到 df_do_AE(已省略)   
    netG:backward(real_A, df_dg + df_do_AE:mul(opt.lambda))

    return errG, gradParametersG
end

在下面代码中,是先得到fake图进入D的loss,然后这个loss的梯度df_do进行反传,首先要这个梯度经过D。此时不能改变D的参数的梯度,所以这里用updateGradInput,不能用backward。这是因为backward是调用2个函数updateGradInputaccGradParameters。后者是计算loss对于网络中参数的梯度,这些梯度是不断累加的!除非手动gradParametersG:zero()置零。

       errG = criterion:forward(output, label)
       local df_do = criterion:backward(output, label)
       df_dg = netD:updateGradInput(fake_AB, df_do):narrow(2,fake_AB:size(2)-output_nc+1, output_nc)
       -- unary loss
       -- 得到 df_do_AE(已省略)   
       netG:backward(real_A, df_dg + df_do_AE:mul(opt.lambda))

然后得到的df_dg才是要更新G的GAN损失的梯度,当然G的另一个损失是L1损失(unary loss)这个没啥好说了。

pytorch的GAN实现

由于Pytorch是自动反向传播,

    def backward_D(self):
        # Fake
        # stop backprop to the generator by detaching fake_B        fake_AB=self.fake_B
        #fake_AB = self.fake_AB_pool.query(torch.cat((self.real_A, self.fake_B), 1))
        self.pred_fake = self.netD.forward(fake_AB.detach())#只要有forward()函数默认进行反向传播,故要通过
        self.loss_D_fake = self.criterionGAN(self.pred_fake, False)#fake_AB.detaach()进行反向截断,不让梯度
                                                                   #流到G中,只反传到最初输入节点即可。
        # Real
        real_AB = self.real_B # GroundTruth
        # real_AB = torch.cat((self.real_A, self.real_B), 1)
        self.pred_real = self.netD.forward(real_AB)
        self.loss_D_real = self.criterionGAN(self.pred_real, True)

        # Combined loss
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5

        self.loss_D.backward()

    def backward_G(self):
        # First, G(A) should fake the discriminator
        fake_AB = self.fake_B
        pred_fake = self.netD.forward(fake_AB)
        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1

        self.loss_G.backward()


    def forward(self):
        self.real_A = Variable(self.input_A)
        self.fake_B = self.netG.forward(self.real_A)
        self.real_B = Variable(self.input_B)

    # 先调用 forward, 再 D backward, 更新D之后; 再G backward, 再更新G
    def optimize_parameters(self):
        self.forward()

        self.optimizer_D.zero_grad()
        self.backward_D()
        self.optimizer_D.step()

        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

解释backward_D:

对于D,如果输入是真实图,要将其与1进行比较,产生loss,输入fake图,D输出结果与0比较,也产生loss。 
通过这两个梯度更新D。如果是真实图(real_B),由于real_B是初始结点,所以没什么可担心的。但是对于生成图fake_B,由于 fake_B是由 netG.forward(real_A)产生的。我们只希望 该loss更新D不要影响到 G. 因此这里需要“截断反传的梯度流”,用 fake_AB = fake_B, fake_AB.detach()从而让梯度不要通过 fake_AB反传到netG中!

解释backward_G:

由于在调用 backward_G已经调用了zero_grad,所以没什么好担心的。 
更新G时,来自D的GAN损失是, netD.forward(fake_AB),得到 pred_fake,然后得到损失,反传播即可。 
注意,这里反向传播时,会先将梯度传到 fake_AB结点,然而我们知道 fake_AB即 fake_B结点,而fake_B正是由netG(real_A)产生的,所以还会顺着继续往前传播,从而得到G的对应的梯度。

对比 Torch代码

df_dg = netD:updateGradInput(fake_AB, df_do):narrow(2,fake_AB:size(2)-output_nc+1, output_nc)
netG:backward(real_A, df_dg + df_do_AE:mul(opt.lambda))

Torch中,没有计算netD的参数的梯度,而是直接用 updateGradInput。在pytorch中,我们也是希望GAN loss只能更新G。但是pytorch是自动求导的,所以我们没法手动像Torch一样只调用updateGradInput

        self.loss_G_GAN = self.criterionGAN(pred_fake, True)

        # Second, G(A) = B
        self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_A

        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

在这里,虽然pytorch中会自动计算所有的结点的梯度,但是我们执行loss_G.backward()后,按照Torch的理解是,这里直接调用backward。即不仅调用了updateGradInput(我们只需要这个),还额外的计算了accGradParameters(这个是没用的),但是看到,在optimize_parameters中,只是进行 optimizer_G.step()所以只会更新G的参数。所以没有更新D(虽然此时D中有dummy gradient)。等下一回合,又调用 optimizer_D.zero_grad(), 因此会把刚才残留的D的梯度清空。所以仍旧是符合的。

自动求导反向书写的简洁

得出结论,书写自动求导的代码完全还是很简洁的。只需要进行loss计算。loss可以直接相加,然后loss.backward()即可。loss的定义比如:

self.optimizer_G = torch.optim.Adam(self.netG.parameters(),
            lr=opt.lr, betas=(opt.beta1, 0.999))

Adam是继承自Optimizer类。该类的step函数会将构建loss的所有的Variable的参数进行更新。

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']: 
             #如果这个参数有没有grad(这个Variable的requries_grad为False)
             #则直接跳过。
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # 对p.data进行更新!就是对参数进行更新!

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = grad.new().resize_as_(grad).zero_()
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = grad.new().resize_as_(grad).zero_()

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                if group['weight_decay'] != 0:
                    grad = grad.add(group['weight_decay'], p.data)

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(1 - beta1, grad)
                exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                p.data.addcdiv_(-step_size, exp_avg, denom)

  • 19
    点赞
  • 83
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值