# Tip: use `with torch.no_grad():` or `t.detach()` to exclude a tensor
# from gradient tracking.
# (Bug fix: this tip line previously lacked the leading `#`, which made
# the whole file a SyntaxError.)
x = torch.tensor(1.0, requires_grad=True)
print("x:", x) # x: tensor(1., requires_grad=True)
y1 = x ** 2
print("y1:", y1) # y1: tensor(1., grad_fn=<PowBackward0>)
y2 = x ** 3
print("y2:", y2) # y2: tensor(1., grad_fn=<PowBackward0>)
# Detach y2 from the autograd graph; equivalent to computing it inside a
# no-grad context:
#     with torch.no_grad():
#         y2 = x ** 3
y2 = y2.detach()
print("y2d:", y2) # y2d: tensor(1.)
y3 = y1 + y2
y3.backward()
print(x.grad) # tensor(2.)
# If y2 had NOT been detached, the gradient would include its term:
# d(x**2 + x**3)/dx at x=1 is 2 + 3 = 5, i.e. x.grad would be tensor(5.)