前情提要
在排查 GAN 训练失真问题时，回顾了 PyTorch 中与梯度相关的知识，特此记录，以备日后查阅。
参考文章
测试代码
import torch
import torch.nn as nn
def test_requires_grad(requires_grad=False):
    """Demonstrate how toggling ``requires_grad`` on one layer affects autograd.

    Four ``nn.Linear(2, 2)`` layers are chained over a seeded random 2x2
    input (the input itself is left with ``requires_grad=False``). Just
    before the third layer runs, its weight and bias get
    ``requires_grad = requires_grad`` (their ``is_leaf`` flag is printed
    first). After ``backward()`` on the summed output, the script prints
    each layer's ``weight.grad``, then the ``grad_fn`` of the input and of
    every intermediate activation.

    Args:
        requires_grad: value assigned to ``requires_grad`` on the third
            layer's parameters.
    """
    frozen_index = 2  # position of the layer whose flag is toggled ("lin2")

    torch.manual_seed(0)  # reproducible input and layer initialisation
    inp = torch.randn(2, 2)
    print('============ input ======== \n {} \n ========================='.format(inp))

    # All layers are constructed before any forward pass, in the same order,
    # so the seeded RNG produces the same weights every run.
    layers = [nn.Linear(2, 2) for _ in range(4)]

    outputs = [inp]
    for position, layer in enumerate(layers):
        if position == frozen_index:
            for param in layer.parameters():
                print('is_leaf: {}'.format(param.is_leaf))
                param.requires_grad = requires_grad
        outputs.append(layer(outputs[-1]))

    outputs[-1].sum().backward()

    for layer in layers:
        print(layer.weight.grad)
    for tensor in outputs:
        print(tensor.grad_fn)
def test_detach(detach=False):
    """Demonstrate how ``Tensor.detach()`` cuts a tensor out of the graph.

    A seeded random 2x2 input with ``requires_grad=True`` is fed through
    four chained ``nn.Linear(2, 2)`` layers. When *detach* is true, the
    second layer's output is detached before entering the third layer.
    After ``backward()`` on the summed output, the script prints each
    layer's ``weight.grad`` (the last one alongside its ``is_leaf`` flag),
    then the ``grad_fn`` of the input and every intermediate activation.

    Args:
        detach: whether to call ``detach()`` on the tensor fed to the
            third layer.
    """
    torch.manual_seed(0)  # reproducible input and layer initialisation
    inp = torch.randn(2, 2)
    print('============ input ======== \n {} \n ========================='.format(inp))
    inp.requires_grad = True
    print(inp.is_leaf)

    first, second, third, fourth = [nn.Linear(2, 2) for _ in range(4)]

    h1 = first(inp)
    h2 = second(h1)
    # A detached tensor has no history, so backward() cannot reach the
    # first two layers through it.
    h3 = third(h2.detach() if detach else h2)
    h4 = fourth(h3)
    h4.sum().backward()

    for layer in (first, second, third):
        print(layer.weight.grad)
    print(fourth.weight.is_leaf, fourth.weight.grad)

    for tensor in (inp, h1, h2, h3, h4):
        print(tensor.grad_fn)
if __name__ == '__main__':
    # To compare gradient flow with and without detach(), call
    # test_detach(True) and test_detach(False) here instead.
    for flag in (True, False):
        test_requires_grad(flag)