pytorch求导
主要记录一下backward()函数以及torch.autograd.grad()函数的使用,并重点探究一下对应的creat_graph
以及retain_graph
参数的使用。
首先定义函数:
import torch
from torch.autograd import Variable
def f(x):
y = x ** 3
return y
1. backward()
主要用于对loss的求导,并将求出的梯度存入叶子节点对应的buffer内。
其中,retain_graph=True
是保留整个计算图不被销毁,这样可以多次backward(),但是要注意梯度是多次累积的。
def main1():
x = Variable(torch.tensor([5.0]), requires_grad=True)
y = f(x)
y.backward(retain_graph=True)
print(x.grad) # 75
y.backward(retain_graph=True)
print(x.grad) # 75 + 75
y.backward()
print(x.grad) # 75 + 75 + 75
2. torch.autograd.grad()
- 首先利用
torch.autograd.grad()
函数进行求导不会在叶子节点的buffer内保留梯度 - 输出的结果为tuple类型,即为一阶导数的值
- 当使用
creat_graph=True
时,可以计算高阶导数,此时grad_x.requires_grad
自动为True
def main2():
x = Variable(torch.tensor([5.0]), requires_grad=True)
grad_x = torch.autograd.grad(f(x), x, create_graph=True)
print(grad_x) # 一阶导数 75 grad_x.requires_grad == True
grad_grad_x = torch.autograd.grad(grad_x[0], x, create_graph=True)
print(grad_grad_x) # 二阶导数 30 requires_grad == True
grad_grad_grad_x = torch.autograd.grad(grad_grad_x[0], x)
print(grad_grad_grad_x) # 三阶导数 6 requires_grad == False
- 其中另一个参数
retain_graph=True
表示的是保留计算图的结果,这样在backward()
的时候会进行累积,单独使用不可以求二阶导,这是因为requires_grad == False
,下面两段代码会报错
def main3():
x = Variable(torch.tensor([5.0]), requires_grad=True)
grad_x = torch.autograd.grad(f(x), x, retain_graph=True)
print(grad_x)
grad_grad_x = torch.autograd.grad(grad_x[0], x, retain_graph=True)
print(grad_grad_x)
grad_grad_grad_x = torch.autograd.grad(grad_grad_x[0], x)
print(grad_grad_grad_x)
def main4():
x = Variable(torch.tensor([5.0]), requires_grad=True)
grad_x = torch.autograd.grad(f(x), x, retain_graph=True)
print(grad_x)
grad_x[0].backward()
print(x.grad)
3. 混合使用
def main5():
x = Variable(torch.tensor([5.0]), requires_grad=True)
grad_x = torch.autograd.grad(f(x), x, create_graph=True)
print(grad_x) # 一阶导数 75
grad_x[0].backward()
print(x.grad) # buffer内存的是30,也就是二阶导数的值