nn.ModuleList()、 nn.Sequential()和 nn.ModuleDict()区别

nn.ModuleList()nn.Sequential(), 和 nn.ModuleDict() 都是 PyTorch 中用于组织和管理神经网络模块的容器,但它们在使用方式和功能上有显著的区别:

nn.ModuleList()

nn.ModuleList() 是一个类似于 Python 列表的容器,用于存储任意数量的 nn.Module 子模块。它提供了列表的操作,如添加、删除或索引访问模块,但所有子模块都被视为模型的一部分,可以被优化器找到。

特点:

  • 动态管理:适合动态添加或删除模块的情况。
  • 非顺序执行:模块的执行顺序不是固定的,而是由你在 forward 方法中遍历列表的顺序决定。
  • 索引访问:通过数字索引访问子模块。
import torch
from torch import nn

class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 5)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

nn.Sequential

nn.Sequential 是一个有序的容器,用于按顺序组合多个 nn.Module 子模块。它简化了前向传播的过程,数据会自动通过这些模块,按照它们在容器中添加的顺序进行处理。

特点:

  • 自动前向传播:一旦模块被添加,前向传播会自动进行,无需在 forward 方法中显式调用。
  • 顺序执行:模块按添加顺序执行,适合线性堆叠的模型结构。
  • 无键名访问:模块只能通过索引访问,没有键名。
import torch
from torch import nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(10, 10),
            nn.ReLU(),
            nn.Linear(10, 5)
        )

    def forward(self, x):
        return self.model(x)

nn.ModuleDict()

nn.ModuleDict() 是一个类似于 Python 字典的容器,用于存储一系列 nn.Module 子模块,每个子模块可以通过字符串键来访问。它提供了字典操作,如添加、删除或键名访问模块。

特点:

  • 命名访问:通过键名访问子模块,适合需要按名称引用模块的情况。
  • 动态管理:可以动态添加或删除模块,类似于 ModuleList 的灵活性。
  • 非顺序执行:虽然模块可以通过键名明确访问,但执行顺序仍需在 forward 方法中指定。
import torch
from torch import nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # 创建一个ModuleDict,并添加三个线性层,每个层都有一个唯一的键名
        self.linears = nn.ModuleDict({
            'linear1': nn.Linear(10, 10),
            'linear2': nn.Linear(10, 5),
            'linear3': nn.Linear(5, 2)
        })

    def forward(self, x):
        # 按名称访问ModuleDict中的子模块
        x = self.linears['linear1'](x)
        x = self.linears['linear2'](x)
        x = self.linears['linear3'](x)
        return x

# 创建模型实例
model = MyModel()

# 创建一个随机输入张量
input_data = torch.randn(20, 10)

# 将输入数据传递给模型
output_data = model(input_data)

# 打印模型的参数
for name, param in model.named_parameters():
    print(name, param.size())

总结

  • 灵活性nn.ModuleList() 和 nn.ModuleDict() 提供了更多的灵活性,允许动态管理和非顺序执行。
  • 访问方式nn.ModuleList() 使用索引访问,而 nn.ModuleDict() 使用键名访问。
  • 前向传播nn.Sequential 自动处理前向传播,而 nn.ModuleList() 和 nn.ModuleDict() 需要在 forward 方法中显式调用每个模块。

选择哪个容器取决于你的模型结构和需求。对于简单的线性堆叠模型,nn.Sequential 是最直接的选择。对于需要更多控制和灵活性的模型,nn.ModuleList() 或 nn.ModuleDict() 更为合适,具体选择取决于你是否需要通过键名访问模块。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值