繁星*春水

关于.sava_for_backward的理解

torch.autograd.function.functionctx.save_for_backward(*tensonrs)

保存传递回来的tensors以便未来backward()

save_for_backward()建议最多被调用一次,只能从forward()方法内部调用,并且只能使用Tensor张量。

所有的tensors 计划被用来向后传播应该使用save_for_backward(而不是直接在ctx)保存,来避免不正确的梯度和内存泄漏.并且启用torch.autograd.graph.saved_tensors_hooks。

注意如果中间张量,既不是forward()的输入,也不是输出张量,被保存在backward()中,你的自定义函数不支持double backward()。自定义函数(不支持向后双精度的自定义函数)以后应该装饰他们的向后传播双精度方法,使用@once_differentiable以便向后传播产生错误。如果你想要支持双精度的backward参见double backward turorial.

在backward()中,被保存的张量可以通过saved_tensors attribute来访问,在他们返回给用户之前,需要检查他们是否被用修改他们内容的原位操作。

参数也可以是None。

观看extending torch.autograd来获得该方法的更多细节。

 class Func(Function):
     @staticmethod
     def forward(ctx, x: torch.Tensor, y: torch.Tensor, z: int):
         w = x * z
         out = x * y + y * z + w * y
         ctx.save_for_backward(x, y, w, out)
         ctx.z = z  # z is not a tensor
         return out

     @staticmethod
     @once_differentiable
     def backward(ctx, grad_out):
        x, y, w, out = ctx.saved_tensors
        z = ctx.z
        gx = grad_out * (y + y * z)
        gy = grad_out * (x + z + w)
        gz = None
        return gx, gy, gz

 a = torch.tensor(1., requires_grad=True, dtype=torch.double)
 b = torch.tensor(2., requires_grad=True, dtype=torch.double)
 d = Func.apply(a, b, c)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值