pytorch的自定义拓展之(二)——torch.autograd.Function完成自定义层

前言:前面的一篇文章中,已经很详细的说清楚了nn.Module、nn.functional、autograd.Function三者之间的联系和区别,虽然autograd.Function本质上是自定义函数的,但是由于神经网络、层、激活函数、损失函数本质上都是函数或者是多个函数的组合,所以使用autograd.Function依然可以达到定义层、激活函数、损失函数、甚至模型的目的,就像我们使用nn.Module是一样的,只不过更偏底层,稍微复杂一些而已,因为需要自己定义求导函数。

但是需要特别注意的是,对于复杂的层或者是网络,使用autograd.Function几乎是不可行的,因为我们需要重新定义反向求导规则即backward函数,而复杂层或者网络没办法写出每一个参数的导函数,或者是即便写出来也是异常复杂(因为链式求导法则再加上一些非线性函数的关系)所以一般不推荐使用autograd.Function去定义层,更不要去定义模型,但是一般定义一个较简单的函数还是可以的。

所以本文依然只会涉及到简单的定义操作,旨在帮助更好地理解autograd.Function的工作过程。关于autograd.Function的详细定义过程,我们可以参考前一篇文章:

pytorch的自定义拓展之(一)——torch.nn.Module和torch.autograd.Function

一、autograd.Function的实例

前面的文章都是使用的实例方法来重写forward和backward方法,下面来看一下如果使用静态方法,像Function类定义的那样,怎么实现。

1.1 重写Function类的静态方法

import torch
from torch.autograd import Function

# 类需要继承Function类,此处forward和backward都是静态方法
class MultiplyAdd(Function):  
                                                             
    @staticmethod                                  
    def forward(ctx, w, x, b):                 
        ctx.save_for_backward(w,x)    #保存参数,这跟前一篇的self.save_for_backward()是一样的
        output = w * x + b
        return output                        
         
    @staticmethod                                 
    def backward(ctx, grad_output):    #获取保存的参数,这跟前一篇的self.saved_variables()是一样的
        w,x = ctx.saved_variables  
        print("=======================================")             
        grad_w = grad_output * x
        grad_x = grad_output * w
        grad_b = grad_output * 1
        return grad_w, grad_x, grad_b  # backward输入参数和forward输出参数必须一一对应
 
x = torch.ones(1,requires_grad=True)  # x 是1,所以grad_w=1
w = torch.rand(1,requires_grad=True)  # w 是随机的,所以grad_x=随机的一个数
b = torch.rand(1,requires_grad=True)  # grad_b 恒等于1

print('开始前向传播')
z=MultiplyAdd.apply(w, x, b)   # forward,这里的前向传播是不一样的,这里没有使用函数去包装自定义的类,而是直接使用apply方法
print('开始反向传播')
z.backward()                   # backward

print(x.grad, w.grad, b.grad)
'''运行结果为:
开始前向传播
开始反向传播
=======================================
tensor([0.1784]) tensor([1.]) tensor([1.])
'''

注意:上面最大的不同除了使用的是静态方法以外,最大的不同在于,我没有使用一个函数去包装我的自定义类,而是直接使用了  z=MultiplyAdd.apply(w, x, b)  去完成前向运算过程,

这个apply方法是定义在Function类的父类_FunctionBase中定义的一个方法,但是这个方法到底是怎么实现的还不得而知。

这里到底是为什么,我还没有搞得特别清楚,因为我没有找到关于apply的详细代码所在,如果有大佬知道,望告知,万分感谢!

二、autograd.Function结合nn.Module

前面已经多次强调,虽然Function类可以用来定义模型,但是不要这么去做,因为Function类本身是为自定义函数而存在,我们在这里演示一下如何使用Function类自定义一个层,然后使用这个自定义的层来搭建网络。

为了简单,本文以搭建一个线性层作为实例说明:

参考下面的文章:

定义torch.autograd.Function的子类,自己定义某些操作,且定义反向求导函数

 

  • 11
    点赞
  • 55
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值