一、网络模型的创建
模型构建的两个要素:
- 构建子模块:在自己建立的模型(继承nn.Module)的_init_()方法
- 拼接子模块:是在模型的forward()方法中
init 函数构建子模块
def __init__(self, classes):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, classes)
forward 函数拼接子模块
def forward(self, x):
out = F.relu(self.conv1(x))
out = F.max_pool2d(out, 2)
out = F.relu(self.conv2(out))
out = F.max_pool2d(out, 2)
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out
创建模型时:net = LeNet()会调用 __init__()
训练模型时:output = net(inputs)会进入nn.module中的__call__()函数
二、nn.Module的属性
- 在模型的概念当中,有一个非常重要的概念叫做nn.Module, 我们所有的模型,所有的网络层都是继承于这个类的。
torch.nn是pytorch的神经网络模块,这里的Module就是它的模块之一,还有几个与Module并列的子模块, 这些子模块协同工作,各司其职。
2.1 nn.Module属性
在nn.Module的__init__()方法中,创建8个重要的属性,
def __init__(self):
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
- _parameters: 存储管理属于nn.Parameter类的属性,例如权值,偏置这些参数
- _modules: 存储管理nn.Module类, 比如LeNet中,会构建子模块,卷积层,池化层,就会 存储在_modules中
- _buffers:存储管理缓冲属性, 如BN层中的running_mean, std等都会存在这里面
- ***_hook:存储管理钩子函数(5个与hooks有关的字典,这个先不用管)
2.2 nn.Module属性构建:
在nn.Module类中进行属性赋值时,被setattr函数拦截,在该函数中,判断即将要赋值的这个数据类型是否是nn.Parameter类,是则存储到parameters这个字典中;如果是nn.Module类,则存储在module这个字典中进行管理
2.3 nn.module总结:
一个module可以包含多个子module(LeNet包含卷积层,池化层,全连接层)
一个module相当于一个运算, 必须实现forward函数(从计算图的角度去理解)
每个module都有8个字典管理它的属性(最常用的就是_parameters,_modules)
class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
三、模型容器Containers
3.1 nn.Seuential
nn.Sequential 是 nn.module的容器,用于按顺序包装一组网络层。
1、根据输入类型可分为:
- 非字典
'''------------Sequential---------------'''
class LeNetSequential(nn.Module):
def __init__(self, classes):
super(LeNetSequential, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 6, 5),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, 5),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),)
self.classifier = nn.Sequential(
nn.Linear(16*5*5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, classes),)
def forward(self, x):
x = self.features(x)
x = x.view(x.size()[0], -1)
x = self.classifier(x)
return x
- 字典型
class LeNetSequentialOrderDict(nn.Module):
def __init__(self, classes):
super(LeNetSequentialOrderDict, self).__init__()
self.features = nn.Sequential(OrderedDict({
'conv1': nn.Conv2d(3, 6, 5),
'relu1': nn.ReLU(inplace=True),
'pool1': nn.MaxPool2d(kernel_size=2, stride=2),
'conv2': nn.Conv2d(6, 16, 5),
'relu2': nn.ReLU(inplace=True),
'pool2': nn.MaxPool2d(kernel_size=2, stride=2),
}))
self.classifier = nn.Sequential(OrderedDict({
'fc1': nn.Linear(16*5*5, 120),
'relu3': nn.ReLU(),
'fc2': nn.Linear(120, 84),
'relu4': nn.ReLU(inplace=True),
'fc3': nn.Linear(84, classes),
}))
def forward(self, x):
x = self.features(x)
x = x.view(x.size()[0], -1)
x = self.classifier(x)
return x
3.2
nn.ModuleList:是nn.module的容器,用于包装一组网络层,以迭代方式调用网络层。
主要方法:
- append():在ModuleList后面添加网络层
- extend():拼接两个ModuleList
- insert():指定在ModuleList中位置插入网络层
'''-------- ModuleList----------'''
class ModuleList(nn.Module):
def __init__(self):
super(ModuleList, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])
def forward(self, x):
for i, linear in enumerate(self.linears):
x = linear(x)
return x
net = ModuleList()
print(net)
fake_data = torch.ones((10, 10))
output = net(fake_data)
print(output)
3.3 nn.ModuleDict
nn.ModuleDict是nn.module的容器,用于包装一组网络层,以索引方式调用网络层。
主要方法:
- clear():清空ModuleDict
- items():返回可迭代的键值对(key-value pairs)
- keys():返回字典的键(key)
- values():返回字典的值(values)
- pop():返回一对键值,并从字典中删除
'''-----------ModuleDict---------------'''
class ModuleDict(nn.Module):
def __init__(self):
super(ModuleDict, self).__init__()
self.choices = nn.ModuleDict({
'conv': nn.Conv2d(10, 10, 3),
'pool': nn.MaxPool2d(3)
})
self.activations = nn.ModuleDict({
'relu': nn.ReLU(),
'prelu': nn.PReLU()
})
def forward(self, x, choice, act):
x = self.choices[choice](x)
x = self.activations[act](x)
return x
net = ModuleDict()
fake_img = torch.randn((4, 10, 32, 32))
output = net(fake_img, 'conv', 'relu') # 在这里可以选择我们的层进行组合
print(output)
容器总结:
nn.sequential:顺序性,各网络层之间严格按照顺序执行,常用于block构建
nn.ModuleList:迭代性,常用于大量重复网络构建,通过for循环实现重复构建
nn.ModuleDict:索引性,常用于可选择的网络层
本文参考:Pytorch学习笔记(4):模型创建(Module)、模型容器(Containers)、AlexNet构建-CSDN博客