2-3 动态计算图

一、动态计算图简介

Pytorch的计算图由节点和边组成,节点表示张量或者Function,边表示张量和Function之间的依赖关系。
Pytorch中的计算图是动态图,那为什么是动态的呢?这里的动态主要有两重含义。
第一层含义是:计算图的正向传播是立即执行的。无需等待完整的计算图创建完毕,每条语句都会在计算图中动态添加节点和边,并立即执行正向传播得到计算结果。
第二层含义是:计算图在反向传播后立即销毁。下次调用需要重新构建计算图。如果在程序中使用了backward方法执行了反向传播,或者利用torch.autograd.grad方法计算了梯度,那么创建的计算图会被立即销毁,释放存储空间,下次调用需要重新创建。

计算图的正向传播是立即执行的

import torch 
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.randn(10,2)
Y = torch.randn(10,1)
Y_hat = X@w.t() + b  # Y_hat定义后其正向传播被立即执行,与其后面的loss创建语句无关
loss = torch.mean(torch.pow(Y_hat-Y,2))

print(loss.data)
print(Y_hat.data)

计算图在反向传播后立即销毁

import torch 
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.randn(10,2)
Y = torch.randn(10,1)
Y_hat = X@w.t() + b  # Y_hat定义后其正向传播被立即执行,与其后面的loss创建语句无关
loss = torch.mean(torch.pow(Y_hat-Y,2))

#计算图在反向传播后立即销毁,如果需要保留计算图(比如求多次反向传播), 需要设置retain_graph = True
loss.backward()  #loss.backward(retain_graph = True) 

#loss.backward() #如果再次执行反向传播将报错

二、计算图中的Function

计算图中的 张量我们已经比较熟悉了, 计算图中的另外一种节点是Function, 实际上就是 Pytorch中各种对张量操作的函数。
这些Function和我们Python中的函数有一个较大的区别,那就是它同时包括正向计算逻辑和反向传播的逻辑。
我们可以通过继承torch.autograd.Function来创建这种支持反向传播的Function

class MyReLU(torch.autograd.Function):
    #正向传播逻辑,可以用ctx存储一些值,供反向传播使用。
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input) # 用ctx保存input
        return input.clamp(min=0) # 使用clamp将负值置零
    #反向传播逻辑
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors # 取出保存的值
        grad_input = grad_output.clone() # 复制
        grad_input[input < 0] = 0
        return grad_input
    
import torch 
w = torch.tensor([[3.0,1.0]],requires_grad=True)
b = torch.tensor([[3.0]],requires_grad=True)
X = torch.tensor([[-1.0,-1.0],[1.0,1.0]])
Y = torch.tensor([[2.0,3.0]])

relu = MyReLU.apply # relu现在也可以具有正向传播和反向传播功能
Y_hat = relu(X@w.t() + b) # 以下计算loss 就可以反向传播
loss = torch.mean(torch.pow(Y_hat-Y,2))

loss.backward() 

print(w.grad)
print(b.grad)

image.png

三、计算图与反向传播

了解了Function的功能,我们可以简单地理解一下反向传播的原理和过程。理解该部分原理需要一些高等数学中求导链式法则的基础知识。

import torch 

x = torch.tensor(3.0,requires_grad=True)
y1 = x + 1
y2 = 2*x
loss = (y1-y2)**2

loss.backward()

loss.backward()语句调用后,依次发生以下计算过程。
1,loss自己的grad梯度赋值为1,即对自身的梯度为1。
2,loss根据其自身梯度以及关联的backward方法,计算出其对应的自变量即y1和y2的梯度,将该值赋值到y1.grad和y2.grad。(loss——>y)
3,y2和y1根据其自身梯度以及关联的backward方法, 分别计算出其对应的自变量x的梯度,x.grad将其收到的多个梯度值累加。(y——>x,所以有loss——>x)
(注意,1,2,3步骤的求梯度顺序和对多个梯度值的累加规则恰好是求导链式法则的程序表述)
正因为求导链式法则衍生的梯度累加规则,张量的grad梯度不会自动清零,在需要的时候需要手动置零。

四、叶子节点和非叶子节点

执行下面代码

import torch 

x = torch.tensor(3.0,requires_grad=True)
y1 = x + 1
y2 = 2*x
loss = (y1-y2)**2

loss.backward()
print("loss.grad:", loss.grad)
print("y1.grad:", y1.grad)
print("y2.grad:", y2.grad)
print(x.grad)

image.png
我们会发现 loss.grad并不是我们期望的1,而是 None。类似地 y1.grad 以及 y2.grad也是 None。
这是为什么呢?这是由于它们不是叶子节点张量。
在反向传播过程中,只有 is_leaf=True 的叶子节点,需要求导的张量的导数结果才会被最后保留下来。
那么什么是叶子节点张量呢?叶子节点张量需要满足两个条件。
1,叶子节点张量是由用户直接创建的张量,而非由某个Function通过计算得到的张量。
2,叶子节点张量的 requires_grad 属性必须为True.
image.png
那有什么解决办法呢?
利用retain_grad可以保留非叶子节点的梯度值,利用register_hook可以查看非叶子节点的梯度值。

import torch ·

#正向传播
x = torch.tensor(3.0,requires_grad=True)
y1 = x + 1
y2 = 2*x
loss = (y1-y2)**2

#非叶子节点梯度显示控制
y1.register_hook(lambda grad: print('y1 grad: ', grad))
y2.register_hook(lambda grad: print('y2 grad: ', grad))
loss.retain_grad()

#反向传播
loss.backward()
print("loss.grad:", loss.grad)
print("x.grad:", x.grad)

image.png
注意这里计算 x 的梯度:
image.png

参考:https://github.com/lyhue1991/eat_pytorch_in_20_days

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值