torch.autograd.Function自定义反向求导

Function自定义反向求导规则

Extending torch.autograd

在某些情况下我们的函数不可微(not differentiable),但是我们仍然需要对他求导时,就需要我们自定义求导方式,这里我们根据PyTorch官网给出的例子,来看一下torch.autograd.Function是如何运行的

官网给出的例子为LinearFunction,代码如下,这里我们假设输入为 2 × 3 2\times3 2×3的矩阵 x x x,权重也为 2 × 3 2\times3 2×3的矩阵 w w w,bias为 2 × 2 2 \times 2 2×2的矩阵 b b b,则

forward
y = x ∗ w T + b y = x*w^{T}+b y=xwT+b

backward

∂ y ∂ x = w T     ∂ y ∂ w = x     ∂ y ∂ b = 1 \frac{\partial y}{\partial x} = w^{T} ~~~ \frac{\partial y}{\partial w} = x ~~~ \frac{\partial y}{\partial b} = 1 xy=wT   wy=x   by=1

from numpy import double
import torch
from torch.autograd import Function
# Inherit from Function
class LinearFunction(Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias

input  = torch.tensor([[2.0, 1.5, 2.5], [1.0, 2.0, 3.0]], dtype=torch.double, requires_grad=True)
weight = torch.tensor([[3.0, 2.0, 3.5], [1.0, 2.0, 3.0]], dtype=torch.double, requires_grad=True)
bias   = torch.tensor([0.1, 0.2], dtype=torch.double, requires_grad=True)

# two ways to use linear operation
# first
output = LinearFunction.apply(input, weight, bias)
print(output)
# second
linear = LinearFunction.apply
output = linear(input, weight, bias)
print(output)


# 检查backward是否计算正确
from torch.autograd import gradcheck
# gradchek takes a tuple of tensor as input, check if your gradient
# evaluated with these tensors are close enough to numerical
# approximations and returns True if they all verify this condition.
test = gradcheck(LinearFunction.apply, (input, weight, bias), eps=1e-6, atol=1e-4)
print(test)  # 没问题的话输出True
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值