nn.ModuleList和nn.Sequential有什么区别,例子

参考:

  • PyTorch 中的 ModuleList 和 Sequential: 区别和使用场景https://zhuanlan.zhihu.com/p/64990232

x.1 二者区别

在PyTorch中,nn.Sequentialnn.ModuleList都是用于组合多个神经网络层的容器。它们的主要区别在于:

  • nn.Sequential是按照顺序组合多个神经网络层的容器,因此必须确保前一个模块的输出大小和下一个模块的输入大小是一致的。而nn.ModuleList只是将不同的模块储存在一起,这些模块之间并没有什么先后顺序可言。
  • nn.Sequential不需要写forward()函数,而nn.ModuleList的内部没有实现forward()函数,因此需要使用forward()函数对其进行调用。

下面是一个例子,假设我们有两个神经网络层:fc1 = nn.Linear(10, 20)fc2 = nn.Linear(20, 30)。我们可以使用以下代码将它们组合成一个神经网络:

fc1 = nn.Linear(10, 20)
fc2 = nn.Linear(20, 30)

# 使用 nn.Sequential
model = nn.Sequential(fc1, nn.ReLU(), fc2)

# 使用 nn.ModuleList
layers = [fc1, nn.ReLU(), fc2]
model = nn.ModuleList(layers)

需要注意的是,nn.ModuleList和nn.Sequential都是nn.Module的子类,可以作为其他模型的组件使用。选择使用哪个取决于你的网络结构和需求的灵活性。

x.2 nn.ModuleList 详细例子

nn.ModuleList具有以下特点:

  • nn.ModuleList是一个简单的Python列表,用于存储和管理各个网络层的实例,
  • 它允许以任意顺序添加、移除或访问网络层。
  • 使用nn.ModuleList时,需要手动编写前向传播函数(forward)来定义网络的计算逻辑。
  • 适用于需要更灵活的网络结构,如跳跃连接(skip connections)或者条件连接。

下面是一个使用nn.ModuleList的简单示例:

import torch
import torch.nn as nn

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 5)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = MyModule()
input = torch.randn(32, 10)
output = model(input)
print(output.shape)

x.3 nn.Sequential

x.3.1 nn.Sequential 详细例子

nn.Sequential具有以下特点:

  • nn.Sequential是一个按顺序组织网络层的容器类,可以通过简单地传递网络层实例列表来构建模型。
  • 它自动定义了前向传播函数,无需手动编写。
  • nn.Sequential的顺序性限制了网络层的连接方式,每一层的输入都是上一层的输出。
  • 适用于按顺序连接的简单网络结构。

下面是一个使用nn.Sequential的简单示例:

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5)
)

input = torch.randn(32, 10)
output = model(input)
print(output.shape)

x.3.2 nn.Sequential 源码实现

在d2l的layer和module(层和块)章节曾谈论到过nn.Sequential如何书写,使用torch.nn.Module.add_module()方法和torch.nn.Module.children()方法,源代码如下:

class MySequential(nn.Module):
    def __init__(self, *args):
        super().__init__()
        for idx, module in enumerate(args):
            self.add_module(str(idx), module)

    def forward(self, X):
        for module in self.children():
            X = module(X)
        return X

net = MySequential(nn.LazyLinear(256), nn.ReLU(), nn.LazyLinear(10))
net(X).shape
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值