模仿pytorch中的自动微分的做法自己实现一个简单的自动求梯度backward

模仿pytorch中的自动微分的做法自己实现一个简单的自动求梯度backward

背景

学习pytorch的人学习到pytorch底层基本都会接触到pytorch的自动求梯度操作,pytorch的自动微分操作在很多文章都有提到,基本原理就是维护一个计算图。但是说是这么说,自己实现一个又该如何实现呢?怎么把链式计算以代码的方式来表达?在网上看到关于这方面的基本都是泛泛而谈,知道维护一个计算图,但是是啥也不清晰。这里会写一个简单的自动微分来让读者对这个部分理解的更加深刻,更加明白pytorch中计算图是怎么维护和计算的。

原理计算

下面先讲一下原理计算,说明白原理计算我们再接着写代码部分。
我们先设立个目标,比如y=2x2+3x+1这个函数,我们需要拆分计算图。我们知道这个函数里面的公式是由一个个四则运算组合而成,我们要做的是将这个四则运算一个个拆解开来。
请添加图片描述

这样子拆解就可以很好用代码写了,可以看出用递归来计算每个梯度就好了。

代码

class Variable:
    def __init__(self, value, grad=0.0):
        self.value = value  # 变量的值
        self.grad = grad    # 变量的梯度
        self.x = None
        self.y = None
        self.if_has_x_y = False
        self.model = None
        
        
    def backward(self):
        if self.if_has_x_y: 
            if self.model == 1:
                self.x.grad += 1 * self.grad  
                self.y.grad += 1 * self.grad
                if self.x.if_has_x_y != None:          
                    self.x.backward()
                if self.y.if_has_x_y != None:
                    self.y.backward()
            elif self.model == 2:
                self.x.grad += self.y.value * self.grad  # 链式法则应用在乘法中
                self.y.grad += self.x.value * self.grad
                if self.x.if_has_x_y != None:          
                    self.x.backward()
                if self.y.if_has_x_y != None:
                    self.y.backward()
    
    def grad_zero(self):
        self.grad = 0
        if(self.x != None):
            self.x.grad_zero()
        if(self.y != None):
            self.y.grad_zero()

def add(x, y):
    if type(x) != Variable or type(y) != Variable:
        assert("input error")
    z = Variable(x.value + y.value)  # 定义加法操作
    z.x = x
    z.y = y
    z.if_has_x_y = True
    z.model = 1
    return z

def mul(x, y):
    if type(x) != Variable or type(y) != Variable:
        assert("input error")
    z = Variable(x.value * y.value)  # 定义乘法操作
    z.x = x
    z.y = y
    z.if_has_x_y = True
    z.model = 2
    return z

首先用python实现一个数据结构,存放梯度,四则运算模式以及是否存在x和y,然后再分别写加法的梯度修改和乘法的梯度修改就好了。
加法和乘法函数没啥好看的,逻辑也不难看懂
下面是运行案例(说一下这个案例里面的y.grad,其实就是y对y自己求导,有一个点是设置其为1也是为了后面的梯度计算不用为其设置特例):

# 计算 y = 2 * x**2 + 3 * x + 1 的导数
x = Variable(2.0)  # 初始化 x = 2.0

# 前向传播
a = mul(Variable(2.0), mul(x, x))  # 2 * x^2
b = mul(Variable(3.0), x)          # 3 * x
y = add(add(a, b), Variable(1.0))  # y = 2 * x^2 + 3 * x + 1

# 反向传播
y.grad = 1.0  # y 的梯度设为 1,表示对自己求导为 1
y.backward()  # 触发反向传播,计算 x 的梯度

print(f"dy/dx at x=2: {x.grad}")  # 输出 x 的梯度

下面是结果:

dy/dx at x=2: 11.0

案例二:

x = Variable(2.0)
y = mul(x,mul(x,x))
y.grad = 1.0  # y 的梯度设为 1,表示对自己求导为 1
y.backward()  # 触发反向传播,计算 x 的梯度

print(f"dy/dx at x=2: {x.grad}")

y.grad_zero()
print(f"dy/dx at x=2: {x.grad}")

结果:

dy/dx at x=2: 12.0
dy/dx at x=2: 0
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值