pytorch如何去定义新的自动求导函数

pytorch定义新的自动求导函数

在pytorch中想自定义求导函数,通过实现torch.autograd.Function并重写forward和backward函数,来定义自己的自动求导运算。参考官网上的demo:传送门

直接上代码,定义一个ReLu来实现自动求导
 

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

import torch

class MyRelu(torch.autograd.Function):

    @staticmethod

    def forward(ctx, input):

        # 我们使用ctx上下文对象来缓存,以便在反向传播中使用,ctx存储时候只能存tensor

        # 在正向传播中,我们接收一个上下文对象ctx和一个包含输入的张量input;

        # 我们必须返回一个包含输出的张量,

        # input.clamp(min = 0)表示讲输入中所有值范围规定到0到正无穷,如input=[-1,-2,3]则被转换成input=[0,0,3]

        ctx.save_for_backward(input)

         

        # 返回几个值,backward接受参数则包含ctx和这几个值

        return input.clamp(min = 0)

    @staticmethod

    def backward(ctx, grad_output):

        # 把ctx中存储的input张量读取出来

        input, = ctx.saved_tensors

         

        # grad_output存放反向传播过程中的梯度

        grad_input = grad_output.clone()

         

        # 这儿就是ReLu的规则,表示原始数据小于0,则relu为0,因此对应索引的梯度都置为0

        grad_input[input < 0] = 0

        return grad_input

 

进行输入数据并测试

1

2

3

4

5

6

7

8

9

10

11

12

13

14

15

16

17

18

19

20

21

22

23

24

25

26

27

28

29

30

31

32

dtype = torch.float

device = torch.device('cuda' if torch.cuda.

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值