原文链接:详解Pytorch动态图的回溯机制
大家好,我是泰哥。《5分钟精通PyTorch》经过1个月的连载,已经介绍了张量的常规操作以及运算技巧。之后的章节就进入到深度学习部分,会以理论与代码结合的方式为大家呈现,帮助大家理解其中细节。
对于动态图回溯机制的学习与理解,首先从张量的微分计算开始入手。
注意
本节我们暂时不区分微分值、导数值、梯度值的区别,后续讲解梯度下降时再进行区分。不理解的同学统一理解为导数即可。
1 Variable
与requires_grad
有同学会说在进行微分运算时需提前将Tensor
类转化为Variable
类,但其实在PyTorch 0.4
版本后,Variable
的概念被逐渐弱化,Tensor
就不再是一个纯计算载体,可微分性也变成了Tensor
的一个基本属性,我们只需要在创建Tensor
时,通过设置requires_grad
属性为True
,就可规定张量可微分计算。
x = torch.tensor(1.,requires_grad = True)
x
# tensor(1., requires_grad=True)
此时张量a
就是一个可微分的张量,requires_grad
是它的一个属性,可以查看并修改。
# 查看可微分性
x.requires_grad
# True
# 修改可微分性
x.requires_grad = False
# 再次查看可微分性
x.requires_grad
# False
2 可微分性的属性
可微分性会体现在可微分张量参与的所有运算中。
requires_grad
属性:可微分性
# 构建可微分张量
x = torch.tensor(1., requires_grad = True)
x
# tensor(1., requires_grad=True)
# 构建函数关系
y = x ** 2
y
# tensor(1., grad_fn=<PowBackward0>)
我们发现,此时张量y
具有了一个grad_fn
属性,并且取值为<PowBackward0>
,我们可以查看该属性:
y.grad_fn
# <PowBackward0 at 0x200a2047208>
grad_fn
存储了Tensor
的微分函数,也可以说它存储了可微分张量在进行计算过程中的函数关系,此处x
到y
进行了幂运算,就是pow
方法,与上述返回属性相对应。
# 但x作为初始张量,并没有grad_fn属性
x.grad_fn
值得注意的是,y不仅和x存在幂运算关系
y
=
x
2
y = x^2
y=x2,更重要的是,y
本身是一个由x
张量计算得出的一个张量:
# 打印y
y
# tensor(1., grad_fn=<PowBackward0>)
对于一个可微分张量(x
)生成的张量(y
),也是可微分的:
y.requires_grad
# True
相比于x
,y
不仅同样拥有张量的取值,也同样可微,还额外存储了x到y的函数计算信息。
我们再尝试围绕y
创建新的函数关系:z = y + 1
:
z = y + 1
z
# tensor(2., grad_fn=<AddBackward0>)
# z同样可微
z.requires_grad
# True
# z保存了y到z的函数函数add
z.grad_fn
# <AddBackward0 at 0x200a2037648>
可以发现z
也存储了数值,并是可微的,还存储了和y
的计算关系(add
)。
3 回溯机制
在PyTorch
的张量计算过程中,如果我们设置初始张量是可微的,则在计算过程中,每一个由原张量计算得出的新张量都是可微的,并且还会保存此前一步的函数关系,这也就是所谓的回溯机制。
根据这个回溯机制,我们就能清楚地掌握张量每一步的计算过程,并据此绘制张量计算图。
4 张量计算图
借助回溯机制,我们就能将张量的复杂计算过程抽象为一张图Graph
,比如我们之前定义的x
、y
、z
三个张量,三者的计算关系就可以由下图进行表示。
计算图定义
计算图模型由节点nodes
和边edges
构成,节点表示操作符,也就是张量,节点之间的边表示张量之间的函数关系,方向则表示实际运算方向。
节点类型
在张量计算图中,虽然每个节点都表示可微分张量,但节点和节点之间却略有不同,比如在前例中:
y
和z
保存了函数计算关系,但x
没有- 可以发现
z
是所有计算的终点
因此可以将节点分为三类,分别是:
- 叶节点:初始输入的可微分张量,前例中的
x
- 输出节点:最后计算得出的张量,前例中的
z
- 中间节点:在一张计算图中,除了叶节点和输出节点,其他都是中间节点,前例中的
y
在一张计算图中,可以有多个叶节点和中间节点,但大多数情况下,只有一个输出节点,若存在多个输出结果,我们也往往会将其保存在一个向量中。
5 计算图的动态性
PyTorch
的计算图是动态计算图,会根据可微分张量的计算过程自动生成,并且伴随着新张量或运算的加入不断更新,这使得PyTorch
的计算图更加灵活高效且容易构建。
而静态图(TF1
)需要先先构建计算流程图,然后进行会话传入数据后才能执行。具体代码比较与分析可看TF与PyTorch该如何选择。
原文链接:详解Pytorch动态图的回溯机制