Pytorch自动求导机制、自定义激活函数和梯度

Pytorch自动求导机制、自定义激活函数和梯度

前言:

由于pytorch框架只是提供了正向传播的机制,模块中的参数的梯度是通过自动求导推倒出来的,当我们需要自定义某一个针对张量的一些列操作时候就不够用了。

1 自动求导机制

Pytorch会根据计算过程来自动生成动态计算图,然后可以根据动态图的创建过程进行反向传播,计算得到每个节点的梯度直。

1.0 张量本身grad_fn

为了能记录张量的梯度,首先需要在张量创建的时候设置 requires_grad =True.

对于pytorch来说,每一个张量都有一个grad_fn方法,这个方法包含着创建该张量的运算的导数信息。本身携带计算图的信息,该方法还有一个next_functions属性,包含链接该张量的其他张量的grad_fn。

1.1 torch.autograd

Pytorch提供了一个专门用来做自动求导的包,torch.autograd.

包含2个重要函数:

1.1.1 torch.autograd.backward

这个函数通过传入根节点张量,以及初始梯度张量,可以计算产生该根节点所对应的叶子节点的梯度。

当张量为标量张量的时候(及只有一个元素的张量)可以部传入初始梯度张量,默认会设置初始梯度张量为1。

当计算梯度张量的时候,原先建立的计算图会自动释放,如果直接再次求导,肯定就会报错。

如果要在反向传播的时候保留计算图,可以设置retain_graph= True.

在自动求导的时候默认是不会建立反向传播图的,如果需要反向传播计算的同时建立和梯度张量相关的计算图,可以设置create_graph=Ture.

另外,对于一个可到的张量,也可以直接调用该张量内部的backward函数来自动求导。

t1=torch.randn(3,3,requires_grad=True)
t2 =t1.pow(2).sum()
#t2对t1张量求导
t2.backward()#反向传播
t1.grad
t2 =t1.pow(2).sum()
t2.backward()#再次反向传播
t1.grad #梯度累计
t1.grad.zero_() # 单个张量清零

1.1.2 torch.autograd.grad

在某些情况下,我们并不需要求出当前张量对所有产生该张量的叶子节点的梯度,这时候我们可以使用torch.autograd.grad方法。

该函数有2个参数,第一个参数是计算图的数据结果张量,第二个参数是需要对计算图求导的张量,最后输出的结果是第一个参数对第二个参数的求导结果,这个输出梯度也是会累计的。

要注意的地方:

1、这个函数部会改变叶子节点的grad属性。

2、反向传播求导时,自动释放计算图,如果要保留,可以设置retain_graph= True.

3、如果需要反向传播计算图,可以设置create_graph=Ture.

t1=torch.randn(3,3,requires_grad=True)
t2 =t1.pow(2).sum()
#t2对t1张量求导
torch.autograd.grad(t2,t1)

2 自定义激活函数和梯度

前言里说了,仅仅使用模块有时候是不能满足我们需要效果的。我们需要自定义激活函数,在激活函数中定义前向传播和反向传播的代码来实现自己的需求。

2.1 类及方法

Pytorch自定义激活函数继承于torch.autograd.Function,其内部有2个静态方法:forward和backward

class Func(torch.autograd.Function):
    @staticmethod
    def forward(ctx,input):
        return result
    
    @staticmethod
    def backward(ctx,grad_output):
        return grad_output

2.2 实例

Quoc V.Le等人的研究成果中,将Swish激活函数定义为

在这里插入图片描述

可以看到,这个公式还是比较复杂的,如果要生成图,中间有部少计算节点。

有了公式之后,我们可以求出导数函数,这样方便进行反向传播。

有了激活函数和其导数函数,我们就可以来自定义相关激活函数了。

swish =Swish.apply #获得激活函数
torch.autograd.gradcheck(
swish,torch.randn(
10,requires_grad =True,
dtype =torch.double)
)
#测试反向传播,正常返回值为True

class Swish(torch.autograd.Function):
    @staticmethod
    def forward(ctx,input):
        ctx.input =input
        return input*torch.sigmoid(1*input) #假设b=1
     @staticmethod
    def backward(ctx,grad_output):
        ctx.input =input
        tmp = torch.sigmoid(1*input)
        
        return grad_output*(tmp +1 *input*tmp(1-tmp))
    

2.3 tips

在上面代码可以看到,我们记录了前像传播和反向传播的过程,并且在backward方法中实现了数值梯度的方法。

可以通过讲apply方法赋值给一个变量的方法来激活自定义的激活函数。

为了保持梯度精度,我们一般都使用双精度类型为张量数值类型

  • 4
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

小菜学AI

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值