import torch
a = torch.tensor([1, 2, 3.], requires_grad=True)
b = torch.tensor([2, 3, 4.], requires_grad=True)
n = a*2
n2 = n.detach()
f = n2 + 3*a/b
#detach 用法 阻断梯度传播 比如此时n2就没有梯度 但是a有 如果把对应a改成b 则b也有
f.sum().backward()
print(a.grad)
对于经常出现的round函数 本身没有梯度 可以采用
w_1 = round(w)-w
w_2 = w_1.detach()
w_3 = w_2 + w
这种方式 绕过对round求梯度 采取用w的梯度代替
代码运行结果:
tensor([3., 3., 3.])