文章目录
模型创建与nn.Module
1、模型创建步骤
模型的创建示意图如下:
从上图中可以看出,模型的创建与权值初始化共同构成了模型,模型的创建只要包括了:
- 构建网络层:卷积层,池化层,激活函数等;
- 拼接网络层:网络层有构建网络层后,需要进行网络层的拼接,拼接成
L
e
N
e
t
LeNet
LeNet,
A
l
e
x
N
e
t
AlexNet
AlexNet和
R
e
s
N
e
t
ResNet
ResNet等
创建好模型后,需要对模型进行权值初始化,PyTorch
中的初始化方法主要有:Xavier
,Kaiming
,均匀分布,正态分布等方法。
2、nn.Module
- 第一部分中讲到的模型的创建与权值初始化在
PyTorch
中均需要通过nn.Module
来完成,nn.Module
是整个模块的根基 nn.Module
是torch.nn
中的模块,torch.nn
中一共有四个模块,如下图所示:
nn.Module
中有八个重要的属性用于管理整个模型:parameters
: 存储管理nn.Parameter
类modules
:存储管理nn.Module
类buffers
:存储管理缓冲属性,如BN层中的running_mean
***_hooks
:共有5个,存储管理钩子函数
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
3、以LeNet
模型为例探究nn.Module
- 如图所示,
LeNet
由很多网络层构成,包括两个卷积层,两个池化层和三个全连接层(LeNet: Conv1 -> pool1 -> Conv2 -> pool2 -> fc1 -> fc2 -> fc3
)
- 将上图转为一个计算图的形式,如下图所示,计算图有两个主要的概念:一个是节点一个是边,节点就是张量数据,边就是运算,在图中就是箭头
- 构建模型有两要素,第一是构建子模块,比如
LeNet
是由很多网络层构成的,所以首先得构建子模块中的网络层;构建好网络层后,第二是拼接子模块,按照一定拓扑结构拼接子模块就可以得到模型,构建子模块需要用到__init__()
函数,而拼接子模块需要用到forward()
函数,下面针对这两个函数进行讲解
(1)初始化部分:__init__()
class LeNet(nn.Module):
def __init__(self, classes):
super(LeNet, self).__init__() # 继承父类nn.Module的初始化
self.conv1 = nn.Conv2d(3, 6, 5) # 卷积层,卷积核为5*5,输入通道为3,输出通道为6
self.conv2 = nn.Conv2d(6, 16, 5) # 卷积层
self.fc1 = nn.Linear(16*5*5, 120)c # 全连接层
self.fc2 = nn.Linear(120, 84) # 全连接层
self.fc3 = nn.Linear(84, classes) # 全连接层
#####(2)拼接部分:forward()
def forward(self, x):
out = F.relu(self.conv1(x)) # import torch.nn.functional as F
out = F.max_pool2d(out, 2)
out = F.relu(self.conv2(out))
out = F.max_pool2d(out, 2)
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out
(3)nn.Module
的属性构建
nn.Module
的属性构建会在module
类中进行属性赋值的时候会被setattr()
函数拦截,在这个函数当中会判断即将要赋值的数据类型是否是nn.parameters
类,如果是的话就会存储到parameters
字典中;如果是module
类就会存储到modul
字典中