pytorch 笔记: 扩展torch.autograd

1 扩展torch.autograd

        向 autograd 添加操作需要为每个操作实现一个新的 Function 子类。

         回想一下,函数是 autograd 用来编码操作历史和计算梯度的东西。

2 何时使用

        通常,如果您想在模型中执行不可微分或依赖非 Pytorch 库中有的函数(例如 NumPy)的计算,但仍希望您的操作与其他操作链接并使用 autograd 引擎,可以通过扩展torch.autograd实现自定义函数 .

        在某些情况下,还可以使用自定义函数来提高性能和内存使用率:如果您使用扩展函数实现了前向和反向传递,则可以将它们包装在 Function 中以与 autograd 引擎交互。

3 何时不建议使用

        如果您已经可以根据 PyTorch 的内置操作编写函数,那么它的后向图(很可能)已经能够被 autograd 记录。 在这种情况下,您不需要自己实现反向传播功能。考虑使用一个普通的pytorch 函数即可。

4 如何使用

采取以下步骤:

1. 创建子类Function, 并实现 forward() 和 backward() 方法。

2. 在 ctx 参数上调用正确的方法。

3. 声明你的函数是否支持双反向传播(double backward)。

4. 使用 gradcheck 验证您的梯度(反向传播的实现)是否正确。

4.1 第一步

        在创建类 Function 之后,您需要定义 2 个方法:

  • forward() 是执行操作的代码(前向传播)

        它可以采用任意数量的参数,其中一些是optimal的(如果设定默认值的话)。

        这里接受各种 Python 对象。

        跟踪历史的张量参数(即 requires_grad=True 的tensor)将在调用之前转换为不跟踪历史的张量参数【但它们如何被使用将在计算图中注册】。请注意,此逻辑不会遍历列表/字典/任何其他数据结构,只会考虑作为调用的直接参数的张量。

        如果有多个输出,您可以返回单个张量输出或张量元组。

  • backward() 定义梯度公式。

      它将被赋予与和forward的输出一样多的张量参数,其中每个参数都代表对应输出的梯度。重要的是永远不要就地修改这些梯度(即不要有inplace操作)。

        它应该返回与forward输入一样多的张量,每个张量都包含对应输入的梯度。

        如果您的输入不需要梯度(needs_input_grad 是一个布尔元组,指示每个输入是否需要梯度计算),或者是非张量对象,您可以返回 None。

        此外,如果你有 forward() 的可选参数,你可以返回比输入更多的梯度,只要它们都是 None。

4.2 第二步

        需要正确使用 forward 的 ctx 相关函数,以确保新函数与 autograd 引擎一起正常工作。

  • save_for_backward() 可以 保存稍后在反向传播中需要使用的、前向的输入张量 。   

                任何东西,即非张量   和既不是输入也不是输出的张量,都应该直接存储在 ctx 上。

  • mark_dirty() 必须用于标记任何由 forward 函数就地修改的输入。
  • 使用 mark_non_differentiable() 来告诉autograd 引擎某一个输出是否不可微。默认情况下,所有可微分类型的输出张量都将设置为需要梯度。不可微分类型(即整数类型)的张量永远不会被标记为需要梯度。
  • set_materialize_grads() 可用于告诉 autograd 引擎在输出不依赖于输入的情况下优化梯度计算,方法是不物化给予后向函数的梯度张量。也就是说,如果设置为 False,python 中的 None 对象或 C++ 中的“未定义张量”(x.defined() 为 False 的张量 x)将不会在向后调用之前转换为填充零的张量,因此您的代码将需要处理这些对象,就好像它们是用零填充的张量一样。此设置的默认值为 True。  

4.3 第三步

        如果你的函数不支持两次反向传播double backward,你应该通过使用 once_differentiable() 向后修饰来明确声明它。 使用此装饰器,尝试通过您的函数执行双重反向传播double backward将产生错误。

4.4 第四步

        建议您使用 torch.autograd.gradcheck() 来检查您的后向函数是否正确计算前向梯度,方法是使用后向函数计算雅可比矩阵,并将值元素与使用有限差分数值计算的雅可比进行比较

5 示例

from torch.autograd import Function
# Inherit from Function
class LinearFunction(Function):

    @staticmethod
    # 注意这里forward和backward都是静态函数
    def forward(ctx, input, weight, bias=None):
        # bias 是一个可选变量,所以可以没有梯度

        ctx.save_for_backward(input, weight, bias)
        #这里就使用了step2中的save_for_backward
        #也就是保存稍后在反向传播中需要使用的、前向的输入或输张量

        output = input.mm(weight.t())

        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output
        #类似的前向传播输出定义

    
    @staticmethod
    def backward(ctx, grad_output):
        #由于这个方法只有一个输出(output),因而在backward中,只需要有一个输入即可(ctx不算的话)

        input, weight, bias = ctx.saved_tensors
        #这呼应的就是前面的ctx.save_for_backward

        grad_input = grad_weight = grad_bias = None

    
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        '''
        这些 needs_input_grad 检查是可选的,只是为了提高效率。 
        如果你想让你的代码更简单,你可以跳过它们。 
        为不需要的输入返回梯度不会返回错误。 
        '''

        return grad_input, grad_weight, grad_bias

现在,为了更容易使用这些自定义操作,我们建议为它们的 apply 方法设置别名:

linear = LinearFunction.apply

这样之后,linear的效果就和我们正常的比如'loss=torch.nn.MSELoss'的loss差不多了 

5.1 用非tensor 参数化的方法

class MulConstant(Function):
    @staticmethod
    def forward(ctx, tensor, constant):
        ctx.constant = constant
        #非tensor的变量就不用save_for_backward了,直接存在ctx里面即可
        return tensor * constant

    @staticmethod
    def backward(ctx, grad_output):

        return grad_output * ctx.constant, None
        #非tensor 变量的梯度为0

6 检查效果

        您可能想检查您实现的反向传播方法是否实际计算了函数的导数。 可以通过与使用小的有限差分的数值近似进行比较:

from torch.autograd import gradcheck


input = (torch.randn(20,20,dtype=torch.double,requires_grad=True), 
         torch.randn(30,20,dtype=torch.double,requires_grad=True))
test = gradcheck(linear, input, eps=1e-6, atol=1e-4)
'''
gradcheck 将张量的元组作为输入,检查用这些张量评估的梯度是否足够接近数值近似值,

如果它们都验证了这个条件,则返回 True。 
'''
print(test)
#True

  • 1
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

UQI-LIUWJ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值