Retain_graph and Create_graph作用
- retain_graph的作用
a = torch.tensor(1, requries_grad=True)
b = torch.tensor(1, requries_grad=True)
c = a**2
d=b*c
c.backward()
d.backward()
这段代码中,执行完c.bckward()
之后graph会自动free,无法计算d.backward()
. 如果retain_graph=True
则可以计算d.backward()
- create_graph作用
gradients = grad(y_,x,\
grad_outputs=torch.ones_like(y_).to('cuda'),\
create_graph=True, \
retain_graph=True, \
only_inputs=False,
allow_unused=True)[0]
这样计算完gradients之后,如果选择create_graph=True
则从x,w到gradients的图会被建立起来,可以进一步地使gradients对x和w求导。