torch.autograd.Function的使用

(个人理解仅供参考)

1 什么情况下使用

自己定义的网络结构,没有现成的,就得手写forward和backward

2 怎么使用

2.1 forward

前向传播的表达式

2.2 backward

求导结果

2.3 举例

前向传播表达式:y = w * x + b
假设f()是我们关于y的loss函数,那么z = f(y)即为loss值
现在要求loss对w、x、b的偏导(假设只有一层):
dz/dx = dz/dy * dy/dx = dz/dy * w
dz/dw = dz/dy * dy/dw = dz/dy * x
dz/db = dz/dy * dy/db = dz/dy * 1
好在dz/dy不用我们再求了,它就是 backward 的参数grad_output。那么grad_output是从哪来的呢?其实就是 forward 会 return output 给 backward ,至于 backward 怎么把 output 变为 grad_output 就不用细究了。
所以:
dz/dx = grad_output * w
dz/dw = grad_output * x
dz/db = grad_output * 1

因此,对于y = w * x + b,我们的代码为:

import torch
from torch.autograd import Function

class MultiplyAdd(Function):

    @staticmethod
    def forward(ctx, w, x, b):
        ctx.save_for_backward(w, x)	 # 保存参数
        output = w * x + b
        return output	# 传给backward

    @staticmethod
    def backward(ctx, grad_output):
        w, x = ctx.saved_tensors
        grad_w = grad_output * x
        grad_x = grad_output * w
        grad_b = grad_output * 1
        return grad_w, grad_x, grad_b	# 传给forward


Linear = MultiplyAdd.apply

2.4 模板

"""
# 使用autograd.Function进行扩展的一个模板
class My_Function(Function):
    def forward(self, inputs, parameters):
        self.saved_for_backward = [inputs, parameters]
        # output = [对输入和参数进行的操作,其实就是前向运算的函数表达式]
        return output
 
    def backward(self, grad_output):
        inputs, parameters = self.saved_tensors # 或self.saved_variables
        # grad_input = [求函数forward(input)关于 parameters 的导数,其实就是反向运算的导数表达式] * grad_output
        return grad_input
"""

2.5 验证

验证的话需要使用torch.autograd.gradcheck,给上我的完整代码,验证部分在最后:

import torch
from torch.autograd import Function, gradcheck

"""
# 使用autograd.Function进行扩展的一个模板
class My_Function(Function):
    def forward(self, inputs, parameters):
        self.saved_for_backward = [inputs, parameters]
        # output = [对输入和参数进行的操作,其实就是前向运算的函数表达式]
        return output
 
    def backward(self, grad_output):
        inputs, parameters = self.saved_tensors # 或者是self.saved_variables
        # grad_input = [求函数forward(input)关于 parameters 的导数,其实就是反向运算的导数表达式] * grad_output
        return grad_input
"""


class MultiplyAdd(Function):

    @staticmethod
    def forward(ctx, w, x, b):
        ctx.save_for_backward(w, x)
        output = w * x + b
        return output

    @staticmethod
    def backward(ctx, grad_output):
        w, x = ctx.saved_tensors
        grad_w = grad_output * x
        grad_x = grad_output * w
        grad_b = grad_output * 1
        return grad_w, grad_x, grad_b


Linear = MultiplyAdd.apply

x = torch.ones(1, requires_grad=True, dtype=torch.float64)
w = torch.rand(1, requires_grad=True, dtype=torch.float64)
b = torch.rand(1, requires_grad=True, dtype=torch.float64)

# print("start forward...")
# z = MultiplyAdd.apply(w, x, b)
# print("start backward...")
# z.backward()
#
# print(x.grad, w.grad, b.grad)

test = gradcheck(Linear, (x, w, b), eps=1e-6)
print(test)

3 存疑

现在我只是会用了这个,但是如果是两层的全连接层,这段代码是怎么工作的?这个问题我还没想明白,留个坑

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值