pytorch篇--梯度及求导

目录
1.pytorch相关函数: Variable、requires_grad、volatile、no_grad、detach()
2. 什么时候使用
3. 参考资料

正文

推测阶段(inteference)指预测阶段,并不反向求导,只有前向计算,求导及梯度更新只出现在训练阶段。
可以参考pytorch官网:
自动求导机制:https://www.pytorchtutorial.com/docs/notes/autograd/
torch.autograd:https://www.pytorchtutorial.com/docs/package_references/torch-autograd/
推荐直接看官网更清楚。

1. 相关函数

a. Variable

pytorch中常使用variable将tensor转化为可梯度更新的形式,即封装了tensor,并整合了反向传播的相关实现。

Varibale包含三个属性:
data:存储了Tensor,是本体的数据
grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致
grad_fn:指向Function对象,用于反向传播的梯度计算之用

可以使用.backward()进行反向传播

import torch
from torch.autograd import Variable
# #Varibale 默认时不要求梯度的,如果要求梯度,需要说明
x = Variable(torch.one(2,2), requires_grad = True)
print(x)#其实查询的是x.data,是个tensor
y = x.sum()
 
y
"""
  Variable containing:
   4
  [torch.FloatTensor of size 1]
"""
 
 
#可以查看目标函数的.grad_fn方法,它用来求梯度
y.grad_fn
"""
    <SumBackward0 at 0x18bcbfcdd30>
"""
 
y.backward()  # 反向传播
x.grad  # Variable的梯度保存在Variable.grad中,为dy/dx
"""
  Variable containing:
   1  1
   1  1
  [torch.FloatTensor of size 2x2]
"""
 
 
#grad属性保存在Variable中,新的梯度下来会进行累加,可以看到再次求导后结果变成了2,
y.backward()
x.grad  # 可以看到变量梯度是累加的
"""
    Variable containing:
     2  2
     2  2
    [torch.FloatTensor of size 2x2]
"""
 
#所以要归零
x.grad.data.zero_()  # 归零梯度,注意,在torch中所有的inplace操作都是要带下划线的,虽然就没有.data.zero()方法
 
"""
 0  0
 0  0
[torch.FloatTensor of size 2x2]
"""
 
 
#对比Variable和Tensor的接口,相差无两
x = Variable(torch.ones(4, 5))
 
y = torch.cos(x)                         # 传入Variable
x_tensor_cos = torch.cos(x.data)  # 传入Tensor
 
print(y)
print(x_tensor_cos)
 
"""
Variable containing:
 0.5403  0.5403  0.5403  0.5403  0.5403
 0.5403  0.5403  0.5403  0.5403  0.5403
 0.5403  0.5403  0.5403  0.5403  0.5403
 0.5403  0.5403  0.5403  0.5403  0.5403
[torch.FloatTensor of size 4x5]
 
 
 0.5403  0.5403  0.5403  0.5403  0.5403
 0.5403  0.5403  0.5403  0.5403  0.5403
 0.5403  0.5403  0.5403  0.5403  0.5403
 0.5403  0.5403  0.5403  0.5403  0.5403
[torch.FloatTensor of size 4x5]
"""

具体可参考链接https://www.cnblogs.com/hellcat/p/8439055.html

在python之前的版本中,variable可以封装tensor,计算反向传播梯度时需要将tensor封装在variable中。但是在pytorch 0.4版本之后,将variable和tensor合并,也就是说不需要将tensor封装在variable中就可以计算梯度。tensor具有variable的性质。

作为能否autograd的标签,requires_grad现在是Tensor的属性,所以,只要当一个操作的任何输入Tensor具有requires_grad = True的属性,autograd就可以自动追踪历史和反向传播了。
具体可参考链接:https://zhuanlan.zhihu.com/p/93502821

a. requires_grad

Variable变量的requires_grad的属性默认为False,若一个节点requires_grad被设置为True,那么所有依赖它的节点的requires_grad都为True。

x=Variable(torch.ones(1))
w=Variable(torch.ones(1),requires_grad=True)
y=x*w
x.requires_grad,w.requires_grad,y.requires_grad
Out[23]: (False, True, True)
b. volatile && no_grad()

在pytorch 0.4版本及之后已被移除,换成no_grad()。使用volatile时会收到warning。
volatile=True是Variable的另一个重要的标识,它能够将所有依赖它的节点全部设为volatile=True,其优先级比requires_grad=True高。因而volatile=True的节点不会求导,即使requires_grad=True,也不会进行反向传播,对于不需要反向传播的情景(inference,测试推断),该参数可以实现一定速度的提升,并节省一半的显存,因为其不需要保存梯度。

>>> x = torch.tensor([1], requires_grad=True)
>>> with torch.no_grad():
...   y = x * 2
>>> y.requires_grad
False
>>> @torch.no_grad()
... def doubler(x):
...     return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False
c. no_grad()和requires_grad区别

当一个节点设为no_grad()时,其依赖的所有节点均为no_grad()。而对requires_grad, 所有节点为false时,输出才为false。

d. detach()

返回一个新的Variable,从当前图中分离下来的。
即使之后重新将它的requires_grad置为true,它也不会具有梯度grad

这样我们就会继续使用这个新的Variable进行计算,后面当我们进行反向传播时,到该调用detach()的Variable就会停止,不能再继续向前进行传播

2. 什么时候使用

测试时推荐使用no_grad()。当你确定你甚至不会调用.backward()时。它比任何其他自动求导的设置更有效——它将使用绝对最小的内存来评估模型。volatile(no_grad())也决定了require_grad is False。

#可以写在很多地方,比如testloader之前
with torch.no_grad():
	for img,target in testloader:

Variable: 当读取数据时,img和target都要转为variable形式,这样计算loss的时候都为variable形式,可以计算loss的反向传播

requires_grad=Fasle时不需要更新梯度, 适用于冻结某些层的梯度;

detach()可以用于减少显存计算。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值