利用torch.autograd.Function自定义层的forward和backward

当我们在Pytorch中想自定义某一层的梯度计算时,可以利用torch.autograd.Function来封装一个class,此时可以我们可以自己在backward方法中自定求解梯度的方法,也适用于不可导函数的backward计算。

这个函数的源代码可以从如下链接获取:
https://pytorch.org/docs/stable/_modules/torch/autograd/function.html

首先给出一个官方提供的demo:

class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = torch.exp(i)
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

这个函数的forward里计算了:
y = e x y=e^{x} y=ex
backward里给定了 y y y的梯度(grad_output),计算 x x x的梯度:
g r a d _ o u t p u t ∗ ∂ y ∂ x = g r a d _ o u t p u t ∗ e x grad\_output*\frac{\partial{y}}{\partial x}=grad\_output*e^{x} grad_outputxy=grad_outputex

  • 首先介绍forward函数,此函数必须接受一个context ctx作为第一个参数,之后可以传入任何参数,ctx可以利用save_for_backward来保存tensors,在backward阶段可以进行获取。forward里定义了前向传播的路径。
  • 之后介绍backward函数,此函数必须接受一个context ctx作为第一个参数,然后是第二个参数grad_output,里面存储forward后tensor的梯度。return的结果是对应每一个input(对应于forward里的各个input)的梯度。ctx.needs_input_grad作为一个boolean型的表示也可以用来控制每一个input是否需要计算梯度,e.g., ctx.needs_input_grad[0] = False,表示forward里的第一个input不需要梯度,若此时我们return时这个位置的梯度值表示为None即可。

利用apply方法可以使用层:其中定义的对象像一个function一样在不同位置可以重复利用。

import torch


class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = torch.exp(i)
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result

exp = Exp()
x1 = torch.tensor([3., 4.], requires_grad=True)
x2 = exp.apply(x1)
y2= exp.apply(x2)

y2.sum().backward()
print(x1.grad)

再提供部分一个复杂一点的线性转换的例子:
y = x W T + b y=xW^{T}+b y=xWT+b
∂ y ∂ x = W , ∂ y ∂ W = x , ∂ y ∂ b = 1 \frac{\partial{y}}{\partial x}=W,\frac{\partial{y}}{\partial W}=x,\frac{\partial{y}}{\partial b}=1 xy=W,Wy=x,by=1
这里定义了Linear模块:

import torch
import torch.nn as nn


class LinearFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_variables
        grad_input = grad_weight = grad_bias = None

        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0).squeeze(0)

        return grad_input, grad_weight, grad_bias


class Linear(nn.Module):
    def __init__(self, input_features, output_features, bias=True):
        super(Linear, self).__init__()
        self.input_features = input_features
        self.output_features = output_features

        self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(output_features))
        else:
            self.register_parameter('bias', None)

        self.weight.data.uniform_(-0.1, 0.1)
        if bias is not None:
            self.bias.data.uniform_(-0.1, 0.1)

    def forward(self, input):
        # See the autograd section for explanation of what happens here.
        return LinearFunction.apply(input, self.weight, self.bias)
©️2020 CSDN 皮肤主题: 编程工作室 设计师:CSDN官方博客 返回首页