Pytorch学习——入门实例(四)继承Autograd.Function子类来自定义函数并实现网络的反向传播

pytorch:定义一个新的Autograd函数

在pytorch的内部,每个Autograd操作符实际上包括对张量的两种操作:

forward函数:从输入向量计算输出向量

backward函数:是接受输出向量对于某个标量的梯度,然后计算输入向量相对于这个标量的梯度。

在pytorch中,可以定义torch.autograd.Function的子类来自定义自己的Autograd函数,并且实现forward与backward两个功能。

接着可以通过构造一个实例,送入输入数据的张量,来使用自定义的Autograd操作符。

接下来是一个例子,在这个例子中使用:

y=a+bP_{3}(c+dx)

其中P3是公式:\frac{1}{2}(5x^{3}-3x)

可以对x手动求导,得到backward中的公式。

通过自定义Autograd函数来计算p3的正向和反向(其中P3函数是我们自定义的legendrePolynomial3)

在这个类中,分别实现了forward与backward两个方法,

import torch
import math
class legendrePolynomial3(torch.autograd.Function):
    #声明静态方法:在使用这个方法的时候,类不需要实例化
    @staticmethod
    def forward(ctx,input):
        ctx.save_for_backward(input)
        return 0.5*(5*input**3 - 3*input)
    @staticmethod
    def backward(ctx,grad_output):
        input, = ctx.saved_tensors
        return grad_output*1.5*(5*input**2 -1)
if __name__ == '__main__':
    dtype = torch.float
    device = torch.device('cpu')

    x = torch.linspace(-math.pi,math.pi,2000,dtype=dtype)
    y = torch.sin(x)

    a = torch.full((),0.0,device=device,dtype = dtype,requires_grad=True)
    b = torch.full((),-1.0,device=device,dtype = dtype,requires_grad=True)
    c = torch.full((),0.0,device=device,dtype = dtype,requires_grad=True)
    d = torch.full((),0.3,device=device,dtype = dtype,requires_grad=True)
    learning_rate = 5e-6
    for t in range(2000):
        #应用函数,使用apply方法
        P3 = legendrePolynomial3.apply
        y_pred = a+b*P3(c+d*x)
        loss = (y_pred - y).pow(2).sum()
        if t%100 ==99:
            print(t,loss.item())
        loss.backward()
        with torch.no_grad():
            a -= learning_rate*a.grad
            b -= learning_rate*b.grad
            c -= learning_rate*c.grad
            d -= learning_rate*d.grad
            a.grad=None
            b.grad = None
            c.grad = None
            d.grad = None
    print(f'结果为: y={a.item()}+{b.item()}*P3({c.item()}+{d.item()}x)')

补充知识:

torch.full函数:

第一个参数是张量的形状,第二个值是一个标量,根据字面意思理解,使用这个标量,填充满这个形状的张量

>>> torch.full((2, 3), 3.141592)
tensor([[ 3.1416,  3.1416,  3.1416],
        [ 3.1416,  3.1416,  3.1416]])

 

  • 2
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值