在PyTorch中,nn容器(nn containers)是用于组织和管理神经网络层的工具,它们使得构建和管理复杂的神经网络结构变得更加简单和高效。以下是几个常用的nn容器及其特点:
1. nn.Sequential
- 定义:
nn.Sequential
是一个有序的容器,它按照传入构造器的顺序,依次创建相应的网络层,并将它们封装成一个整体。 - 特点:
- 有序性:内部的层按照传入的顺序进行排列,前一个层的输出自动作为后一个层的输入。
- 简化代码:使用
nn.Sequential
可以简化代码,避免显式地编写前向传播逻辑。 - 灵活性受限:由于内部层的顺序是固定的,因此在需要复杂前向传播逻辑时可能不够灵活。
- 示例:
-
import torch.nn as nn model = nn.Sequential( nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10), nn.LogSoftmax(dim=1) ) print(model) 输出结果: Sequential( (0): Linear(in_features=784, out_features=128, bias=True) (1): ReLU() (2): Linear(in_features=128, out_features=10, bias=True) (3): LogSoftmax(dim=1) )
2. nn.ModuleList
- 定义:
nn.ModuleList
是一个持有多个子模块的列表,它继承自nn.Module
,因此可以被视为一个特殊的列表容器。 - 特点:
- 迭代性:可以像普通的Python列表一样迭代
nn.ModuleList
中的模块。 - 自动注册:添加到
nn.ModuleList
中的模块会被自动注册到整个网络中,这意味着它们的参数会被包含在net.parameters()
中。 - 无前向传播:与
nn.Sequential
不同,nn.ModuleList
本身不定义前向传播逻辑,需要在自定义的forward
方法中实现。
- 迭代性:可以像普通的Python列表一样迭代
- 示例:
-
import torch.nn as nn class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() self.linears = nn.ModuleList([nn.Linear(10, 10) for _ in range(4)]) def forward(self, x): for linear in self.linears: x = linear(x) return x model = MyModule() print(model) 输出结果: MyModule( (linears): ModuleList( (0): Linear(in_features=10, out_features=10, bias=True) (1): Linear(in_features=10, out_features=10, bias=True) (2): Linear(in_features=10, out_features=10, bias=True) (3): Linear(in_features=10, out_features=10, bias=True) ) )
3. nn.ModuleDict
- 定义:
nn.ModuleDict
是一个持有多个子模块的字典,它同样继承自nn.Module
。 - 特点:
- 命名索引:通过键值对的方式存储模块,可以通过键名来索引和访问模块。
- 自动注册:与
nn.ModuleList
类似,添加到nn.ModuleDict
中的模块也会被自动注册到整个网络中。 - 灵活性:允许根据需求动态地添加、删除或修改模块。
- 示例:
import torch.nn as nn class MyModule(nn.Module): def __init__(self): super(MyModule, self).__init__() self.layers = nn.ModuleDict({ 'conv1': nn.Conv2d(1, 20, 5), 'pool': nn.MaxPool2d(2, 2), 'conv2': nn.Conv2d(20, 50, 5) }) def forward(self, x): x = self.layers['conv1'](x) x = self.layers['pool'](x) x = self.layers['conv2'](x) return x model = MyModule() print(model) 输出结果: MyModule( (layers): ModuleDict( (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1)) (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False) (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1)) ) )