【小白学PyTorch】动态图与静态图的浅显理解

这个系列是重新整理的一个《小白学PyTorch系列》。文章来自微信公众号【机器学习炼丹术】,喜欢的话动动小手关注下公众号吧~


本章节缕一缕PyTorch的动态图机制与Tensorflow的静态图机制(最新版的TF也支持动态图了似乎)。

1 动态图的初步推导

  • 计算图是用来描述运算的有向无环图
  • 计算图有两个主要元素:结点(Node)和边(Edge);
  • 结点表示数据 ,如向量、矩阵、张量;
  • 边表示运算 ,如加减乘除卷积等;

上图是用计算图表示:

y = ( x + w ) ∗ ( w + 1 ) y = ( x + w ) ∗ ( w + 1 ) y=(x+w)∗(w+1)y=(x+w)∗(w+1) y=(x+w)(w+1)y=(x+w)(w+1)

其中呢, a = x + w a=x+w a=x+w b = w + 1 b=w+1 b=w+1 , y = a ∗ b y=a∗b y=ab. (a和b是类似于中间变量的那种感觉。)

Pytorch在计算的时候,就会把计算过程用上面那样的动态图存储起来。现在我们计算一下y关于w的梯度:

∂ y ∂ w = ∂ y ∂ a ∂ a ∂ w + ∂ y ∂ b ∂ b ∂ w \frac{\partial y}{\partial w} = \frac{\partial y}{\partial a} \frac{\partial a}{\partial w} + \frac{\partial y}{\partial b} \frac{\partial b}{\partial w} wy=aywa+bywb
= 2 × w + x + 1 = 5 =2\times w + x + 1=5 =2×w+x+1=5

(上面的计算中,w=1,x=2)

现在我们用Pytorch的代码来实现这个过程:

import torch
w = torch.tensor([1.],requires_grad = True)
x = torch.tensor([2.],requires_grad = True)

a = w+x
b = w+1
y = a*b

y.backward()
print(w.grad)

得到的结果:

2 动态图的叶子节点

这个图中的叶子节点,是w和x,是整个计算图的根基。之所以用叶子节点的概念,是为了减少内存,在反向传播结束之后,非叶子节点的梯度会被释放掉 , 我们依然用上面的例子解释:

import torch
w = torch.tensor([1.],requires_grad = True)
x = torch.tensor([2.],requires_grad = True)

a = w+x
b = w+1
y = a*b

y.backward()
print(w.is_leaf,x.is_leaf,a.is_leaf,b.is_leaf,y.is_leaf)
print(w.grad,x.grad,a.grad,b.grad,y.grad)

运行结果是:

可以看到只有x和w是叶子节点,然后反向传播计算完梯度后(.backward()之后),只有叶子节点的梯度保存下来了。

当然也可以通过.retain_grad()来保留非任意节点的梯度值。

import torch
w = torch.tensor([1.],requires_grad = True)
x = torch.tensor([2.],requires_grad = True)

a = w+x
a.retain_grad()
b = w+1
y = a*b

y.backward()
print(w.is_leaf,x.is_leaf,a.is_leaf,b.is_leaf,y.is_leaf)
print(w.grad,x.grad,a.grad,b.grad,y.grad)

运行结果:

3. grad_fn

torch.tensor有一个属性grad_fn,grad_fn的作用是记录创建该张量时所用的函数,这个属性反向传播的时候会用到。例如在上面的例子中,y.grad_fn=MulBackward0,表示y是通过乘法得到的。所以求导的时候就是用乘法的求导法则。同样的,a.grad=AddBackward0表示a是通过加法得到的,使用加法的求导法则。

import torch
w = torch.tensor([1.],requires_grad = True)
x = torch.tensor([2.],requires_grad = True)

a = w+x
a.retain_grad()
b = w+1
y = a*b

y.backward()
print(y.grad_fn)
print(a.grad_fn)
print(w.grad_fn)

运行结果是:

叶子节点的.grad_fn是None。

4 静态图

两者的区别用一句话概括就是:

  • 动态图:pytorch使用的,运算与搭建同时进行;灵活,易调节。
  • 静态图:老tensorflow使用的,先搭建图,后运算;高效,不灵活。

静态图我们是需要先定义好运算规则流程的。比方说,我们先给出

a = x + w a = x+w a=x+w , b = w + 1 b=w+1 b=w+1 , y = a × b y=a\times b y=a×b

然后把上面的运算流程存储下来,然后把w=1,x=2放到上面运算框架的入口位置进行运算。而动态图是直接对着已经赋值的w和x进行运算,然后变运算变构建运算图。

在一个课程http://cs231n.stanford.edu/slides/2018/cs231n_2018_lecture08.pdf中的第125页,有这样的一个对比例子:

这个代码是Tensorflow的,构建运算的时候,先构建运算框架,然后再把具体的数字放入其中。整个过程类似于训练神经网络,我们要构建好模型的结构,然后再训练的时候再吧数据放到模型里面去。又类似于在旅游的时候,我们事先定要每天的行程路线,然后每天按照路线去行动。

动态图呢,就是直接对数据进行运算,然后动态的构建出运算图。很符合我们的运算习惯。

两者的区别在于,静态图先说明数据要怎么计算,然后再放入数据。假设要放入50组数据,运算图因为是事先构建的,所以每一次计算梯度都很快、高效;动态图的运算图是在数据计算的同时构建的,假设要放入50组数据,那么就要生成50次运算图。这样就没有那么高效。所以称为动态图

动态图虽然没有那么高效,但是他的优点有以下:

  1. 更容易调试。
  2. 动态计算更适用于自然语言处理。(这个可能是因为自然语言处理的输入往往不定长?)
  3. 动态图更面向对象编程,我们会感觉更加自然。
以下是小白PyTorch的一些教程: 1. 官方文档:PyTorch提供了详细的官方文档,从安装到使用教程,以及高级深度学习开发的资料。PyTorch的第一步是查看官方文档:https://pytorch.org/docs/stable/index.html 2. PyTorch中文文档:如果英语不是很好,这是一个很好的PyTorch中文文档。虽然有一些不是很清晰或者过时的部分,但是它仍然是较好的教程之一。:https://pytorch-cn.readthedocs.io/zh/latest/ 3. PyTorch Handbook:PyTorch Handbook汇集了PyTorch的基础知识和高级技巧,适合新手习,也适合进阶使用PyTorch的人参考。:https://github.com/zergtant/pytorch-handbook 4. Udacity深度学习班“入门PyTorch”课程:入门PyTorch是Udacity的深度学习班的一门课程。 该课程提供了关于PyTorch的综合介绍,包括从张量到神经网络的构建。该课程的重点是实战:利用 PyTorch 实现著名的 MNIST 实例,训练卷积神经网络,基于迁移习的像分类等等。:https://www.udacity.com/course/deep-learning-pytorch--ud188 5. PyTorch实战教程:完整的 PyTorch 实战教程,包括深度神经网络,零件库,像和自然语言处理等:https://github.com/yunjey/pytorch-tutorial 6. PyTorch 60分钟教程:PyTorch 60分钟教程是 PyTorch 的入门课程,该课程提供了有关 PyTorch 库和 API 的指南。:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 7. 深度学习理论入门:这本书不仅介绍了深度学习领域的基础知识,还介绍了用PyTorch实现深度学习模型的方法,并且包含了许多实际案例示例。:https://github.com/huanhuanZhang/rampy/tree/main/PyTorch 以上是小白PyTorch的一些教程。PyTorch是一个强大的深度学习框架,它的文档和教程都很详细。选择合适的教程和实践,不断探索和习,才能真正掌握这个框架。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值