深度之眼Pytorch框架训练营第四期——PyTorch的计算图与动态图机制

Pytorch的计算图与动态图机制

1、计算图(Computational Graph)
  • 计算图是一个用来描述运算的有向无环图
  • 计算图有两个主要元素:结点(Node)和(Edge):
  • 结点表示数据:向量,矩阵,张量等
  • 表示运算,如加减乘除卷积等
  • 例子:利用计算图表示 y = ( x + w ) ∗ ( w + 1 ) y=(x+w) *(w+1) y=(x+w)(w+1)
    在这里插入图片描述
  • 第一步:创建 x x x w w w
  • 第二步:令 a = x + w , b = w + 1 , y = a ∗ b a=x+w, b=w+1, y=a * b a=x+w,b=w+1,y=ab
    这样就可以得到如上图所示的计算图,利用计算图来描述运算的好处不仅仅是让运算更加简洁,还有一个更加重要的作用是使梯度求导更加方便,例如上图中,如果需要求解 ∂ y ∂ w \frac{\partial y}{\partial w} wy,则可以按照下图所示的步骤求解:

在这里插入图片描述

计算过程为:
∂ y ∂ w = ∂ y ∂ a ∂ a ∂ w + ∂ y ∂ b ∂ b ∂ w = b ∗ 1 + a ∗ 1 = b + a = ( w + 1 ) + ( x + w ) = 2 ∗ w + x + 1 = 2 ∗ 1 + 2 + 1 = 5 \begin{aligned} \frac{\partial y}{\partial w} = & \frac{\partial y}{\partial a} \frac{\partial a}{\partial w}+\frac{\partial y}{\partial b} \frac{\partial b}{\partial w} \\ = & b * 1+a * 1\\ = & b + a \\ = & (w+1)+(x+w) \\ = & 2 * w+x+1 \\ = & 2 * 1+2+1 \\ = & 5 \end{aligned} wy=======aywa+bywbb1+a1b+a(w+1)+(x+w)2w+x+121+2+15
本质上, y y y w w w求导就是在计算图中找到所有 y y y w w w的路径,把路径上的导数进行求和[red]

w = torch.tensor([1.], requires_grad=True)
x = torch.tensor([2.], requires_grad=True)
a = torch.add(w, x)     # retain_grad()
b = torch.add(w, 1)
y = torch.mul(a, b)
y.backward()
print(w.grad)
# tensor([5.]) 与理论值相符
  • 计算图深入分析(以上图为例):
  • 张量的属性中有一个与梯度相关的属性——is_leaf,也就是叶子节点,功能就是是指示张量是否是叶子节点,如果是用户创建的节点,则为叶子节点 x , w x,w x,w),而通过计算得到的节点则不是叶子节点 a , b , y a,b,y a,b,y
  • 叶子节点是整个计算图的根基,例如前面求导的计算图,在前向传导中的 a , b , y a,b,y a,b,y都要依据创建的叶子节点 x , y x,y x,y进行计算的;在反向传播过程中,所有梯度的计算也都要依赖叶子节点
  • 设置叶子节点主要目的是为了节省内存,Pytorch在梯度反向传播结束之后,非叶子节点的梯度都会被释放掉,而叶子结点的梯度会保留下来,如果想保留非叶子结点梯度,可以使用retain_grad()进行保留
  • 张量的属性中还有一个属性——grad_fn,作用是记录创建该张量时所用的方法(函数),在梯度反向传播的时候会用到这个属性,例如上图中,y在反向传播的时候会记录y是用乘法得到的,所用在求解ab的梯度的时候就会用到乘法的求导法则去求解ab的梯度。而求解 a , b a,b a,b的梯度时,由于 a a a b b b是通过加法得到的,不对使用链式法则求导
print("is_leaf:\n", w.is_leaf, x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)
# is_leaf:
#  True True False False False
# 查看梯度
print("gradient:\n", w.grad, x.grad, a.grad, b.grad, y.grad)
# gradient:
#  tensor([5.]) tensor([2.]) None None None
# 查看 grad_fn
print("grad_fn:\n", w.grad_fn, x.grad_fn, a.grad_fn, b.grad_fn, y.grad_fn)
# grad_fn:
#  None None <AddBackward0 object at 0x7f8f839f1a50> <AddBackward0 object at 0x7f8f81e77c90> <MulBackward0 object at 0x7f8f839feb50>
2、动态图机制

根据计算图搭建方式,可将计算图分为动态图静态图

  • 静态图:先搭建图,后运算(高效但不灵活)
    在这里插入图片描述
  • 动态图:运算与搭建同时进行(灵活易调节)
    在这里插入图片描述

静态图是先将图搭建好之后,再放数据进去;而动态图,是根据每一步的计算搭建的

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值