Pytorch中nn.ModuleList和nn.Sequential的用法和区别

最近在定义一个多任务网络的时候对nn.ModuleList和nn.Sequential的用法产生了疑惑,这里让我们一起来探究一下二者的用法和区别。

nn.ModuleList和nn.Sequencial的作用
先来探究一下nn.ModuleList的作用,定义一个简单的只含有全连接的网络来看一下。当不使用ModuleList只用list来定义网络中的层的时候:

import torch
import torch.nn as nn
 
class testNet(nn.Module):
    def __init__(self):
        super(testNet, self).__init__()
        self.combine = []
        self.combine.append(nn.Linear(100,50))
        self.combine.append(nn.Linear(50,25))
 
net = testNet()
print(net)


可以看到结果并没有显示添加的全连接层信息。如果改成采用ModuleList:

import torch
import torch.nn as nn
 
class testNet(nn.Module):
    def __init__(self):
        super(testNet, self).__init__()
        self.combine = nn.ModuleList()
        self.combine.append(nn.Linear(100,50))
        self.combine.append(nn.Linear(50,25))
 
net = testNet()
print(net)


可以看到定义的时候pytorch可以自动识别nn.ModuleList中的参数而普通的list则不可以。

并且用Sequential也能达到和ModuleList同样的效果。

nn.ModuleList和nn.Sequential的区别
那ModuleList和Sequential有什么区别呢,如果我们用nn.ModuleList定义一个网络,并给网络定义一个输入,查看输出结果:

import torch
import torch.nn as nn
 
class testNet(nn.Module):
    def __init__(self):
        super(testNet, self).__init__()
        self.combine = nn.ModuleList()
        self.combine.append(nn.Linear(100,50))
        self.combine.append(nn.Linear(50,25))
 
testnet = testNet()
input_x = torch.ones(100)
output_x = testnet(input_x) 
print(output_x)
会报错NotImplementedError:

这是因为没有实现forward()方法,如果将forward()方法补全如下:

import torch
import torch.nn as nn
 
class testNet(nn.Module):
    def __init__(self):
        super(testNet, self).__init__()
        self.combine = nn.ModuleList()
        self.combine.append(nn.Linear(100,50))
        self.combine.append(nn.Linear(50,25))
    #补全forward()
    def forward(self, x):
        x = self.combine(x)
 
        return x
 
testnet = testNet()
input_x = torch.ones(100)
output_x = testnet(input_x) 
print(output_x)
发现仍旧会报NotImplementedError错误。这是因为nn.ModuleList是一个无序性的序列,他并没有实现forward()方法,我们不能通过直接调用x = self.combine(x)的方法来实现forward()。如果想要实现ModuleList的方法需要如下定义forward():

import torch
import torch.nn as nn
 
class testNet(nn.Module):
    def __init__(self):
        super(testNet, self).__init__()
        self.combine = nn.ModuleList()
        self.combine.append(nn.Linear(100,50))
        self.combine.append(nn.Linear(50,25))
    #重新定义forward()方法
    def forward(self, x):
        x = self.combine[0](x)
        x = self.combine[1](x)
 
        return x
 
testnet = testNet()
input_x = torch.ones(100)
output_x = testnet(input_x) 
print(output_x)


得到了正确的结果。如果替换成nn.Sequential定义:

import torch
import torch.nn as nn
 
class testNet(nn.Module):
    def __init__(self):
        super(testNet, self).__init__()
        self.combine = nn.Sequential(
            nn.Linear(100,50),
            nn.Linear(50,25),
        ) 
    
    def forward(self, x):
        x = self.combine(x)
 
        return x
 
testnet = testNet()
input_x = torch.ones(100)
output_x = testnet(input_x) 
print(output_x)


也能得到同样的结果。

我查阅了一些资料是这么说的:

In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module

翻译一下,意思就是说:

nn.Sequential定义的网络中各层会按照定义的顺序进行级联,因此需要保证各层的输入和输出之间要衔接。并且nn.Sequential实现了farward()方法,因此可以直接通过类似于x=self.combine(x)的方式实现forward。

而nn.ModuleList则没有顺序性要求,并且也没有实现forward()方法。

这是二者之间的区别。

两者实际都为容器、ModuleList没有顺序性要求,并且没有实现forward()方法。

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值