一、 nn.Module()模块
所有网络层的基类,管理网络属性。
一个module可以包含多个子module.
一个module相当于一个运算,必须实现forward()方法,而backward函数会被自动实现(利用Autograd).
每个module都有8个字典管理其属性.
二、nn.Sequential()容器
nn.Sequential()是nn.module()的容器,用于按顺序包装一组网络层。
顺序性:各个网络层之间严格按照顺序构造。
自带forward():自带的forward中,通过for循环依次执行前向传播运算。
三、nn.ModuleList()容器
nn.ModuleList()是nn.module()的容器,用于包装一组网络层,以迭代的方式调用网络层。
append():在ModuleList后面添加网络层.
extend():拼接两个Modulelist.
insert():在ModuleList的指定位置插入网络层.
四、nn.ModuleDict()容器
nn.ModuleDict()是nn.module()的容器,用于包装一组网络层,以索引的方式调用网络层。
clear():清空ModuleDict.
items():返回可迭代的键值对(key-value pairs).
keys():返回字典的键(key).
values():返回字典的值(value).
pop():返回一对键值,并从字典中删除.
三、举例
1、nn.ModuleList():
import torch
import torchvision
import torch.nn as nn
from collections import OrderedDict
class LeNetSequence(nn.Module):
def __init__(self, classes):
super(LeNetSequence, self).__init__()
self.features = nn.Sequential( # 向容器中添加多各个网络层
nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, 5),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2))
self.classifier = nn.Sequential(
nn.Linear(16 * 5 * 5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.Linear(84, classes))
def forward(self, x): # 在类中写forward()函数后,可在主函数直接调用LeNet()类即实现了前向传播和反向传播.
x = self.features(x)
x = x.view(x.size()[0], -1) # 和reshape方法差不多
x = self.classifier(x)
return x
class LeNetSequentialOrderDict(nn.Module):
def __init__(self, classes):
super(LeNetSequentialOrderDict, self).__init__()
self.features = nn.Sequential(OrderedDict({ # OrderedDict可以实现自命名各个网络层
'conv1': nn.Conv2d(3, 6, 5),
'relu1': nn.ReLU(inplace=True),
'pool1': nn.MaxPool2d(kernel_size=2, stride=2),
'conv2': nn.Conv2d(6, 16, 5),
'relu2': nn.ReLU(inplace=True),
'pool2': nn.MaxPool2d(kernel_size=2, stride=2),
}))
self.classifier = nn.Sequential(OrderedDict({
'fc1': nn.Linear(16 * 5 * 5, 120),
'relu3': nn.ReLU(),
'fc2': nn.Linear(120, 84),
'relu4': nn.ReLU(inplace=True),
'fc3': nn.Linear(84, classes),
}))
def forward(self, x):
x = self.features(x)
x = x.view(x.size()[0], -1)
x = self.classifier(x)
return x
# net = LeNetSequence(classes=2)
net = LeNetSequentialOrderDict(classes=2)
fake_img = torch.randn((4, 3, 32, 32), dtype=torch.float32) # 输入的Tensor数据
output = net(fake_img)
# print(net) # 打印网络结构
print(output) # 输出结果
运行结果:
tensor([[ 0.0364, -0.0334],
[ 0.0380, -0.0300],
[ 0.0315, -0.0312],
[ 0.0330, -0.0234]], grad_fn=<AddmmBackward>)
2、nn.ModuleList():
class ModuleList(nn.Module):
def __init__(self):
super(ModuleList, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])
def forward(self, x):
for i, linear in enumerate(self.linears):
x = linear(x)
return x
net = ModuleList()
print(net)
fake_data = torch.ones((10, 10))
output = net(fake_data)
运行结果:
ModuleList(
(linears): ModuleList(
(0): Linear(in_features=10, out_features=10, bias=True)
(1): Linear(in_features=10, out_features=10, bias=True)
(2): Linear(in_features=10, out_features=10, bias=True)
(3): Linear(in_features=10, out_features=10, bias=True)
(4): Linear(in_features=10, out_features=10, bias=True)
(5): Linear(in_features=10, out_features=10, bias=True)
(6): Linear(in_features=10, out_features=10, bias=True)
(7): Linear(in_features=10, out_features=10, bias=True)
(8): Linear(in_features=10, out_features=10, bias=True)
(9): Linear(in_features=10, out_features=10, bias=True)
(10): Linear(in_features=10, out_features=10, bias=True)
(11): Linear(in_features=10, out_features=10, bias=True)
(12): Linear(in_features=10, out_features=10, bias=True)
(13): Linear(in_features=10, out_features=10, bias=True)
(14): Linear(in_features=10, out_features=10, bias=True)
(15): Linear(in_features=10, out_features=10, bias=True)
(16): Linear(in_features=10, out_features=10, bias=True)
(17): Linear(in_features=10, out_features=10, bias=True)
(18): Linear(in_features=10, out_features=10, bias=True)
(19): Linear(in_features=10, out_features=10, bias=True)
)
)
3、nn.ModuleDict():
class ModuleDict(nn.Module):
def __init__(self):
super(ModuleDict, self).__init__()
self.choices = nn.ModuleDict({
'conv': nn.Conv2d(10, 10, 3),
'pool': nn.MaxPool2d(3)
})
self.activations = nn.ModuleDict({
'relu': nn.ReLU(),
'prelu': nn.PReLU()
})
def forward(self, x, choice, act):
x = self.choices[choice](x)
x = self.activations[act](x)
return x
net = ModuleDict()
fake_img = torch.randn((4, 10, 32, 32))
output = net(fake_img, 'conv', 'relu') # 根据key值选取value