pytorch基础知识整理(一)自动求导机制

torch.autograd

torch.autograd是pytorch最重要的组件,主要包括Variable类和Function类,Variable用来封装Tensor,是计算图上的节点,Function则定义了运算操作,是计算图上的边。

1.Tensor
tensor张量和numpy数组的区别是它不仅可以在cpu上运行,还可以在GPU上运行。
tensor其实包含一个信息头和一个数据存储类型torch.Storage,torch.Storage是一个单一数据类型的连续一维数组。可以用tensor.is_contiguous()检验张量的数据存储是否是在连续内存上,tensor必须连续才能够使用view操作改变tensor形状,如果不连续则可以使用tensor.contiguous()使之连续。

2.Variable
注:自Pytorch0.4.0版本之后,variable类型和tensor类型合并,在代码中不用再把tensor转换为variable。

var = Variable(tensor, requires_grad=True)

在计算图中,只要有一个节点使用了requires_grad=True,它的后续关联节点都会成为requires_grad=True,就是说都需要计算梯度。可以

var = Variable(tensor, volatile=True)

只要有一个节点使用volatile=True整个计算图就不会调用.backward(),用于推理过程。
注意在Variable不支持inplace运算操作,因为这样导致变量值被更改,反向传播的时候无法再使用,因此为了避免计算错误,计算图中出现inplace操作时, pytorch会报错。
可用tensor.is_leaf()判断某个变量在计算图中是否是叶子节点,只有叶子节点会保留grad,其他张量不保留,如果非叶子节点需要保留梯度,则使用tensor.retain_grad()即可。

2.1 .backward()对计算图进行反向传播更新梯度
对于计算图中的一个标量,比如损失函数的输出loss,可以直接进行.backward()操作。如果是一个张量,比如中间过程,则必须指定和该张量同形状的grad_tensor,具体涉及到反向传播过程(复合函数链式法则求偏导)的jacobian矩阵。
2.2 torch.nn.Parameter()
Parameter是Variable的子类,但parameter类会出现在模型的参数列表中(即会出现在model.parameters()迭代器中),且parameter类默认requires_grad=True,且不能设置volatile。
2.3 冻结网络部分参数
可以用detach()把张量从计算图中分离出来,分离出来的变量不求梯度,可以用来冻结部分网络权重参数(代码示例待补充)。也可以通过设置网络前面部分参数的requires_grad=False来冻结网络。

model = torchvision.models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
model.fc = nn.Linear(512, 100)
optimizer = optim.SGD(model.fc.parameters(), lr=1e-2, momentum=0.9)

3.Function
Function是对Variable进行的运算,定义了forward()方法和backward()方法。可以与nn.Module()对比来理解。两者都可以实现运算,但是Function无法保存参数,用于不需要更新参数的操作,例如各种激活函数、池化等运算,而Module可以保存参数,则用于线性层、卷积层等运算。使用Function自定义运算时必须重写forward()和backward()方法,而使用Module自定义运算时,只需要写forward()即可,backward()可由Module中的各种组件自动求解了。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值