模仿pytorch中的自动微分的做法自己实现一个简单的自动求梯度backward
背景
学习pytorch的人学习到pytorch底层基本都会接触到pytorch的自动求梯度操作,pytorch的自动微分操作在很多文章都有提到,基本原理就是维护一个计算图。但是说是这么说,自己实现一个又该如何实现呢?怎么把链式计算以代码的方式来表达?在网上看到关于这方面的基本都是泛泛而谈,知道维护一个计算图,但是是啥也不清晰。这里会写一个简单的自动微分来让读者对这个部分理解的更加深刻,更加明白pytorch中计算图是怎么维护和计算的。
原理计算
下面先讲一下原理计算,说明白原理计算我们再接着写代码部分。
我们先设立个目标,比如y=2x2+3x+1这个函数,我们需要拆分计算图。我们知道这个函数里面的公式是由一个个四则运算组合而成,我们要做的是将这个四则运算一个个拆解开来。
这样子拆解就可以很好用代码写了,可以看出用递归来计算每个梯度就好了。
代码
class Variable:
def __init__(self, value, grad=0.0):
self.value = value # 变量的值
self.grad = grad # 变量的梯度
self.x = None
self.y = None
self.if_has_x_y = False
self.model = None
def backward(self):
if self.if_has_x_y:
if self.model == 1:
self.x.grad += 1 * self.grad
self.y.grad += 1 * self.grad
if self.x.if_has_x_y != None:
self.x.backward()
if self.y.if_has_x_y != None:
self.y.backward()
elif self.model == 2:
self.x.grad += self.y.value * self.grad # 链式法则应用在乘法中
self.y.grad += self.x.value * self.grad
if self.x.if_has_x_y != None:
self.x.backward()
if self.y.if_has_x_y != None:
self.y.backward()
def grad_zero(self):
self.grad = 0
if(self.x != None):
self.x.grad_zero()
if(self.y != None):
self.y.grad_zero()
def add(x, y):
if type(x) != Variable or type(y) != Variable:
assert("input error")
z = Variable(x.value + y.value) # 定义加法操作
z.x = x
z.y = y
z.if_has_x_y = True
z.model = 1
return z
def mul(x, y):
if type(x) != Variable or type(y) != Variable:
assert("input error")
z = Variable(x.value * y.value) # 定义乘法操作
z.x = x
z.y = y
z.if_has_x_y = True
z.model = 2
return z
首先用python实现一个数据结构,存放梯度,四则运算模式以及是否存在x和y,然后再分别写加法的梯度修改和乘法的梯度修改就好了。
加法和乘法函数没啥好看的,逻辑也不难看懂
下面是运行案例(说一下这个案例里面的y.grad,其实就是y对y自己求导,有一个点是设置其为1也是为了后面的梯度计算不用为其设置特例):
# 计算 y = 2 * x**2 + 3 * x + 1 的导数
x = Variable(2.0) # 初始化 x = 2.0
# 前向传播
a = mul(Variable(2.0), mul(x, x)) # 2 * x^2
b = mul(Variable(3.0), x) # 3 * x
y = add(add(a, b), Variable(1.0)) # y = 2 * x^2 + 3 * x + 1
# 反向传播
y.grad = 1.0 # y 的梯度设为 1,表示对自己求导为 1
y.backward() # 触发反向传播,计算 x 的梯度
print(f"dy/dx at x=2: {x.grad}") # 输出 x 的梯度
下面是结果:
dy/dx at x=2: 11.0
案例二:
x = Variable(2.0)
y = mul(x,mul(x,x))
y.grad = 1.0 # y 的梯度设为 1,表示对自己求导为 1
y.backward() # 触发反向传播,计算 x 的梯度
print(f"dy/dx at x=2: {x.grad}")
y.grad_zero()
print(f"dy/dx at x=2: {x.grad}")
结果:
dy/dx at x=2: 12.0
dy/dx at x=2: 0