PyTorch内部机制的理解

本文详细解析PyTorch中的反向传播机制,包括理解计算图、参数更新、optimizer.zero_grad()的作用、detach()的功能以及requires_grad和volatile的含义。此外,还探讨了在多个Loss和复杂网络结构中如何计算梯度及更新参数,以及nn.Module的重要特性及其使用技巧。
摘要由CSDN通过智能技术生成

反向传播与参数更新的理解

  首先,反向传播或许被称为“反向求导”更加合适,因为它只是个求导的过程,即计算中间参数的梯度。在PyTorch中,通过loss.backward()进行反向求导,关于loss.backward()有两点需要注意:【1】loss是标量(零维张量),只有标量才能直接使用 backward();【2】loss.backward()的完整写法是loss.backward(retain_graph=False),其中的形参retain_graph的意义在于是否保留计算图,默认为False,即在反向传播后,自动释放当前计算图,节省资源的同时,为下一次反向传播做好准备。
  一般情况下是每次迭代,只需一次 forward() 和一次 backward() 。但是不排除,由于自定义loss等有多个,网络需要计算多个不同lossbackward()产生的梯度,来更新参数。于是,如果在当前loss.backward()后,还需要执行其他lossbackward(),那么就需要在当前的loss.backward()时,指定保留计算图,即loss.backward(retain_graph=True)
  需要特别注意的是,反向传播求导与网络参数更新是两个不同的过程。必须要先反向求导,再进行参数更新,其代码逻辑大致如下:

loss.backward() # 先在构建好的计算图进行反向传播,计算中间变量的梯度,计算完后,立即释放计算图
optimizer.step() # 根据计算好的梯度,进行网络的参数更新

计算图的概念

  计算图可以说是输入变量一直到输出变量的逻辑运算关系,是模型前向forward() 和后向求梯度backward() 的流程参照。这里需要注意的是,能获取回传梯度(grad)的只有计算图的叶节点(即输入节点,在loss.backward()后)。中间节点的梯度在计算求取并回传之后就会被释放掉,没办法获取。想要获取中间节点梯度,可以使用 register_hook (钩子)函数工具。当然, register_hook 不仅仅只有这个作用。

网络中间变量的梯度

  如上所述,,能获取回传梯度(grad)的只有计算图的叶节点(即输入节点,在loss.backward()后)。中间节点的梯度在计算求取并回传之后就会被释放掉,没办法获取。想要获取中间节点梯度,可以使用 register_hook (钩子)函数工具。

为什么需要optimizer.zero_grad()

  在编写Pytorch的训练代码时,下面一段代码是非常常见的:

	for i in range(batch): # 在每个Batch都执行如下操作
		# zero the parameter gradients
        optimizer.zero_grad()
        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

  在这里明确一下zero_grad()函数的作用以及为什么需要在每个batch都执行该操作:根据pytorch中backward()函数的计算,当进行求导时,梯度是累积计算而不是被替换,但在处理每一个batch时并不需要与其他batch的梯度混合起来累积计算,因此需要对每个batch调用一遍zero_grad()将当前可变参数的梯度置0。
  当然,我们也可以不选择每个batch都清除一次梯度,比如两次或多次再清除一次这样相当于提高了batch_size,对GPU的内存需要更高。

detach()的理解

  detach() ,如果 x 为中间输出,x_1 = x.detach 表示创建一个与 x 相同,但requires_grad==False 的新Tensor (相当于是把x_1以前的计算图 grad_fn 都消除了),x_1也就成了叶节点(输入节点)。原先反向传播时,回传到x时还会继续,而现在回到x_1处后,就结束了,不继续回传求导了,在x之前的网络参数也就不再进行更新了。另外值得注意的是detach_() 表示不创建新张量,而是直接修改 x 本身。

requires_grad和volatile的理解

  如果对于张量 x ,如果 x.requires_grad == True , 则表示它可以参与求导,也可以从它向后求导。默认情况下,一个新的Tensor的 requires_gradFalse
  可以向后求导的意思是说,requires_grad == True 具有传递性,如果:

x.requires_grad == True 
y.requires_grad == False  
z = f(x,y)

z.requires_grad == True,注意requires_grad == False 则不具有传递性,在PyTorch中,凡是参与运算的变量(包括输入、输出、中间输出、网络权重参数等),都可以设置requires_grad。但是一般来说,输入是否含requires_grad=True是无所谓的,因为需要更新的是网络权重的参数,如果通过nn方法调用卷积层等,会默认其中的参数的requires_grad=True。当然,设置输入的requires_grad=True会更保险,但是可能会带来一些计算和内存上的代价。
  volatilePyTorch1.0版本后已经被移除了。实际上,volatile==True 就等价于 requires_grad==Falsevolatile==True 同样具有传递性。一般只用在推理过程中。若是从某个中间输出x张量 开始都只需做推理,而不需反传梯度的话,那么只需设置x.volatile=True ,那么 x 以后的运算过程得到的输出张量均为 volatile==Tru

  • 3
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值