转载自:
https://blog.csdn.net/u012436149/article/details/70145598
https://blog.csdn.net/douhaoexia/article/details/78821428对PyTorch架构的粗浅理解,不能保证完全正确,但是希望可以从更高层次上对PyTorch上有个整体把握。水平有限,如有错误,欢迎指错,谢谢!
先来看一下整个优化过程:首先调用前向(forward)通道来计算输出和损失,然后调用反向通道(backward)得到模型的导数。最后按照权重合并这些导数更新模型以期最小化损失。 前向传播计算损失,反向传播损失优化,更新各个网络权重。
backward:核心:BP算法,即利用梯度反传优化损失函数。(完了把BP算法自己重新好好推导一下)
optimizer.zero_grad() #将参数的grad值初始化为0
optimizer = torch.optim.SGD(itertools.chain(net1.parameters(), net2.parameters()),lr=0.001, momentum=0.9) # 这里 net1 和net2 优化的先后没有区别 !!
optimizer.zero_grad() #将参数的grad值初始化为0
# forward + backward + optimize
outputs1 = net1(inputs) #input 未置requires_grad为True,但不影响
outputs2 = net2(outputs1)
loss = criterion(outputs2, labels) #计算损失
loss.backward() #反向传播
几个重要的类型
和数值相关的
- Tensor
- Variable
- Parameter
- buffer(这个其实不能叫做类型,其实他就是用来保存tensor的)
Tensor: PyTorch
中的计算基本都是基于Tensor
的,可以说是PyTorch
中的基本计算单元。
Variable: Tensor
的一个Wrapper
,其中保存了Variable
的创造者,Variable
的值(tensor),还有Variable
的梯度(Variable
)。
自动求导机制的核心组件,因为它不仅保存了 变量的值,还保存了变量是由哪个op
产生的。这在反向传导的过程中是十分重要的。
Variable
的前向过程的计算包括两个部分的计算,一个是其值的计算(即,Tensor的计算),还有就是Variable
标签的计算。标签指的是什么呢?如果您看过PyTorch的官方文档 Excluding subgraphs from backward
部分的话,您就会发现Variable
还有两个标签:requires_grad
和volatile
。标签的计算指的就是这个。
Parameter:
这个类是Variable
的一个子集,PyTorch
给出这个类的定义是为了在Module
(下面会谈到)中添加模型参数方便。
模型相关的
- Function
- Module
Function:
如果您想在PyTorch
中自定义OP
的话,您需要继承这个类,您需要在继承的时候复写forward
和backward
方法,可能还需要复写__init__
方法(由于篇幅控制,这里不再详细赘述如果自定义OP
)。您需要在forward
中定义OP
,在backward
说明如何计算梯度。
关于Function
,还需要知道的一点就是,Function
中forward
和backward
方法中进行计算的类型都是Tensor
,而不是我们传入的Variable。计算完forward和backward之后,会包装成Varaible返回。这种设定倒是可以理解的,因为OP是一个整体嘛,OP内部的计算不需要记录creator
Module:
这个类和Function
是有点区别的,回忆一下,我们定义Function
的时候,Funciton
本身是不需要变量的,而Module
是变量和Function
的结合体。在某些时候,我们更倾向称这种结构为Layer
。但是这里既然这么叫,那就这么叫吧。
Module
实际上是一个容器,我们可以继承Module
,在里面加几个参数,从而实现一个简单全连接层。我们也可以继承Module
,在里面加入其它Module
,从而实现整个VGG
结构。
关于hook
PyTorch中注册的hook都是不允许改变hook的输入值的
下面对PyTorch中出现hook的地方做个总结:
* Module : register_forward_hook, register_backward_hook
注意:forward_hook不能用来修改Module的输出值,它的功能就像是安装个监视器一样。我们可以用forward_hook和visdom来监控我们Module的输出。backward_hook和与Variable
的功能是类似的,将和Variable
的register_hook
一起介绍。
- Variable: register_hook
Variable的register_hook注册的是一个backward hook
,backward hook
是在BP的过程中会用到的。可以用它来处理计算的梯度。
foward过程与backward过程
forward
以一个Module为例:
1. 调用module的call
方法
2. module
的call
里面调用module
的forward
方法
3. forward
里面如果碰到Module
的子类,回到第1步,如果碰到的是Function
的子类,继续往下
4. 调用Function
的call
方法
5. Function
的call
方法调用了Function的forward
方法。
6. Function
的forward
返回值
7. module
的forward
返回值
8. 在module
的call
进行forward_hook
操作,然后返回值。
backward
关于backward
总结
PyTorch基本的操作是OP
,被操作数是Tensor
。