PyTorch结构介绍

转载自:

https://blog.csdn.net/u012436149/article/details/70145598

https://blog.csdn.net/douhaoexia/article/details/78821428

对PyTorch架构的粗浅理解,不能保证完全正确,但是希望可以从更高层次上对PyTorch上有个整体把握。水平有限,如有错误,欢迎指错,谢谢!

先来看一下整个优化过程:首先调用前向(forward)通道来计算输出和损失,然后调用反向通道(backward)得到模型的导数。最后按照权重合并这些导数更新模型以期最小化损失。   前向传播计算损失,反向传播损失优化,更新各个网络权重。


backward:核心:BP算法,即利用梯度反传优化损失函数。(完了把BP算法自己重新好好推导一下)

optimizer.zero_grad() #将参数的grad值初始化为0

    optimizer = torch.optim.SGD(itertools.chain(net1.parameters(), net2.parameters()),lr=0.001, momentum=0.9) # 这里 net1 和net2 优化的先后没有区别 !!
    optimizer.zero_grad() #将参数的grad值初始化为0
    # forward + backward + optimize
    outputs1 = net1(inputs)            #input 未置requires_grad为True,但不影响
    outputs2 = net2(outputs1)
    loss = criterion(outputs2, labels) #计算损失
    loss.backward()                    #反向传播  


几个重要的类型

和数值相关的

  • Tensor
  • Variable
  • Parameter
  • buffer(这个其实不能叫做类型,其实他就是用来保存tensor的)

Tensor
PyTorch中的计算基本都是基于Tensor的,可以说是PyTorch中的基本计算单元。

Variable: 
Tensor的一个Wrapper,其中保存了Variable的创造者,Variable的值(tensor),还有Variable的梯度(Variable)。

自动求导机制的核心组件,因为它不仅保存了 变量的值,还保存了变量是由哪个op产生的。这在反向传导的过程中是十分重要的。

Variable的前向过程的计算包括两个部分的计算,一个是其值的计算(即,Tensor的计算),还有就是Variable标签的计算。标签指的是什么呢?如果您看过PyTorch的官方文档 Excluding subgraphs from backward 部分的话,您就会发现Variable还有两个标签:requires_gradvolatile。标签的计算指的就是这个。

Parameter
这个类是Variable的一个子集,PyTorch给出这个类的定义是为了在Module(下面会谈到)中添加模型参数方便。

模型相关的

  • Function
  • Module

Function
如果您想在PyTorch中自定义OP的话,您需要继承这个类,您需要在继承的时候复写forwardbackward方法,可能还需要复写__init__方法(由于篇幅控制,这里不再详细赘述如果自定义OP)。您需要在forward中定义OP,在backward说明如何计算梯度。 
关于Function,还需要知道的一点就是,Functionforwardbackward方法中进行计算的类型都是Tensor,而不是我们传入的Variable。计算完forward和backward之后,会包装成Varaible返回。这种设定倒是可以理解的,因为OP是一个整体嘛,OP内部的计算不需要记录creator

Module
这个类和Function是有点区别的,回忆一下,我们定义Function的时候,Funciton本身是不需要变量的,而Module是变量和Function的结合体。在某些时候,我们更倾向称这种结构为Layer。但是这里既然这么叫,那就这么叫吧。

Module实际上是一个容器,我们可以继承Module,在里面加几个参数,从而实现一个简单全连接层。我们也可以继承Module,在里面加入其它Module,从而实现整个VGG结构。

关于hook

PyTorch中注册的hook都是不允许改变hook的输入值的 
下面对PyTorch中出现hook的地方做个总结: 
* Module : register_forward_hook, register_backward_hook 
注意:forward_hook不能用来修改Module的输出值,它的功能就像是安装个监视器一样。我们可以用forward_hook和visdom来监控我们Module的输出。backward_hook和与Variable的功能是类似的,将和Variableregister_hook一起介绍。

  • Variable: register_hook 
    Variable的register_hook注册的是一个backward hookbackward hook是在BP的过程中会用到的。可以用它来处理计算的梯度。

关于hook较为详尽的介绍

foward过程与backward过程

forward 
以一个Module为例: 
1. 调用module的call方法 
2. modulecall里面调用moduleforward方法 
3. forward里面如果碰到Module的子类,回到第1步,如果碰到的是Function的子类,继续往下 
4. 调用Functioncall方法 
5. Functioncall方法调用了Function的forward方法。 
6. Functionforward返回值 
7. moduleforward返回值 
8. 在modulecall进行forward_hook操作,然后返回值。

backward 
关于backward

总结

PyTorch基本的操作是OP,被操作数是Tensor

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值