torch.nn是专门为神经网络设计的模块化接口,nn.Module是nn中十分重要的类。在介绍该模块前,我们先看下pytorch官方对该模块的注释:
根据官方注释我们了解到Module类是所有神经网络模块的基类,Module可以以树形结构包含其他的Module。Module类中包含网络各层的定义及forward方法,下面介绍我们如何定义自已的网络:
-
需要继承nn.Module类,并实现forward方法;
-
一般把网络中具有可学习参数的层放在构造函数__init__()中;
-
不具有可学习参数的层(如ReLU)可在forward中使用nn.functional来代替;
-
只要在nn.Module的子类中定义了forward函数,利用Autograd自动实现反向求导。
那么这时候就有一些疑问:
- 为什么要继承nn.Module?
- forward函数什么时候会被调用?
answer:
1、关于第一个问题,我们需要看下Module类的源码,Module初始化后就相当于8个有序字典,因此,当实例化你定义的Net(nn.Module的子类)时,要确保父类的构造函数首先被调用,这样才能确保上述8个OrderedDict被create。
_modules:桥梁作用,在获取一个net的所有的parameters的时候,是通过递归遍历该net的所有_modules来实现的。
2、forward函数需要通过Net(input)(Net为自己定义的类)来调用,而非Net.forward(input),因为前者实现了额外的功能:
a) 先执行_forward_pre_hooks里的所有hooks
b) 再调用forward函数
c) 执行_forward_hooks中所有hooks
d) 执行_backward_hooks中所有hooks
_forward_pre_hooks通常只有一些Norm操作会定义_forward_pre_hooks,这种hook不能改变input的内容;_forward_hooks不改变input和output,目前就是方便自己测试的时候用;_backward_hooks和_forward_hooks类似。所以,网络中没有Norm操作,使用Net(input)和Net.forward(input)是等价的。
以上是针对torch.nn.Module模块的介绍,构建模型过程中的一些问题总结及理解后续更新。