关于.sava_for_backward的理解
torch.autograd.function.functionctx.save_for_backward(*tensonrs)
保存传递回来的tensors以便未来backward()
save_for_backward()建议最多被调用一次,只能从forward()方法内部调用,并且只能使用Tensor张量。
所有的tensors 计划被用来向后传播应该使用save_for_backward(而不是直接在ctx)保存,来避免不正确的梯度和内存泄漏.并且启用torch.autograd.graph.saved_tensors_hooks。
注意如果中间张量,既不是forward()的输入,也不是输出张量,被保存在backward()中,你的自定义函数不支持double backward()。自定义函数(不支持向后双精度的自定义函数)以后应该装饰他们的向后传播双精度方法,使用@once_differentiable以便向后传播产生错误。如果你想要支持双精度的backward参见double backward turorial.
在backward()中,被保存的张量可以通过saved_tensors attribute来访问,在他们返回给用户之前,需要检查他们是否被用修改他们内容的原位操作。
参数也可以是None。
观看extending torch.autograd来获得该方法的更多细节。
class Func(Function):
@staticmethod
def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
w = x * z
out = x * y + y * z + w * y
ctx.save_for_backward(x, y, w, out)
ctx.z = z # z is not a tensor
return out
@staticmethod
@once_differentiable
def backward(ctx, grad_out):
x, y, w, out = ctx.saved_tensors
z = ctx.z
gx = grad_out * (y + y * z)
gy = grad_out * (x + z + w)
gz = None
return gx, gy, gz
a = torch.tensor(1., requires_grad=True, dtype=torch.double)
b = torch.tensor(2., requires_grad=True, dtype=torch.double)
d = Func.apply(a, b, c)