pytorch技巧 六: ModuleList和Sequential

pytorch技巧 六: ModuleList和Sequential

在pytorch搭建模型的过程中经常会碰到 ModuleList和Sequential模块,谨以此文记录自己对这两个模块的理解,本人才疏学浅,希望各位不吝赐教。

1. 简介

nn.Sequential:

介绍这个模块前,我们要知道一个重要观点,就是在pytorch中,核心是Module类。而Sequential就是继承自Module类。它就像一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行。
例一:

import torch.nn as nn

class net1(nn.Module):
    def __init__(self):
        super(net1, self).__init__()
        self.seq = nn.Sequential(nn.Conv2d(3, 10, 3),
                                 nn.ReLU(),
                                 nn.Conv2d(10, 32, 3),
                                 nn.ReLU()
                                 )
        self.conv1 = nn.Conv2d(32, 10, 3)
    def forward(self, x):
        x = self.seq(x)
        x = self.conv1(x)
        return x

mymodel = net1()
print(mymodel)



结果:

net1(
  (seq): Sequential(
    (0): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(10, 32, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
  )
  (conv1): Conv2d(32, 10, kernel_size=(3, 3), stride=(1, 1))
)

在这个网络中由两个block组成,self.seq和self.conv1,其中self.seq是由nn.Sequential封装好的,按顺序进行计算,默认按顺序命名(0,1,2…),也可以使用OrderedDict来指定每个module的名字:

import torch.nn as nn
from collections import OrderedDict

class net1(nn.Module):
    def __init__(self):
        super(net1, self).__init__()
        self.seq = nn.Sequential(OrderedDict([('conv0_0', nn.Conv2d(3, 10, 3)),
                                             ('relu0_0', nn.ReLU()),
                                             ('conv0_1', nn.Conv2d(10, 32, 3)),
                                             ('relu0_1', nn.ReLU())]
                                 ))
        self.conv1 = nn.Conv2d(32, 10, 3)
    def forward(self, x):
        x = self.seq(x)
        x = self.conv1(x)
        return x

mymodel = net1()
print(mymodel)



输出:

net1(
  (seq): Sequential(
    (conv0_0): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))
    (relu0_0): ReLU()
    (conv0_1): Conv2d(10, 32, kernel_size=(3, 3), stride=(1, 1))
    (relu0_1): ReLU()
  )
  (conv1): Conv2d(32, 10, kernel_size=(3, 3), stride=(1, 1))
)

这里还有个特需要注意的地方,就是他的forward函数的写法,等会与nn. ModuleList对比!!!

nn. ModuleList:

这个模块也是继承自Module类,它是一个储存多个 module,并自动将每个 module 的 parameters 添加到网络之中的容器。你可以把任意 nn.Module 的子类 加到这个容器里面,方法和 Python 自带的 list 一样,无非是 extend,append 等操作。但不同于普通list,加入到 nn.ModuleList 里面的 module 是会自动注册到整个网络上的,同时 module 的 parameters 也会自动添加到整个网络中。为什么会有这个模块?因为我们每写一个module,就要写forward函数,module很多的话就比较麻烦。
例二:

import torch.nn as nn
from collections import OrderedDict

class net2(nn.Module):
    def __init__(self):
        super(net2, self).__init__()
        self.modulelist = nn.ModuleList([nn.Conv2d(3, 10, 3),
                                  nn.ReLU(),
                                  nn.Conv2d(10, 32, 3),
                                  nn.ReLU()
                                         ])
        self.conv1 = nn.Conv2d(32, 10, 3)
    def forward(self, x):
        for m in self.modulelist:  # 区别于例一
            x = m(x)
        x = self.conv1(x)
        return x

mymodel = net2()
print(mymodel)



输出:

net2(
  (modulelist): ModuleList(
    (0): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(10, 32, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
  )
  (conv1): Conv2d(32, 10, kernel_size=(3, 3), stride=(1, 1))
)

2. 区别

区别一:
Sequential内部实现了forward函数,因此可以不用写forward函数。而ModuleList则没有实现内部forward函数必须要写。这也就是为什么我强调关注forward函数,仔细看看例一与例二中forward函数,例一是Sequential模块,在forward函数中没有逐一对Sequential模块内部进行forward,而例二是ModuleList模块,在forward函数中逐一对ModuleList内部进行forward!
因为Sequential内部实现了forward函数,我们也可以这样构建模型:

seq = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )
print(seq)

ModuleList不可以这样使用,因为他没有实现forward函数。但是一般我们不这样用,而是把Sequential作为一个block放在Module模块中。让代码看起来更加简洁,可读性高。

区别二:

Sequential模块是按照顺序运算的,所以必须确保前一个模块输出大小和下一个模块输入大小一致。而ModuleList不一定是按照顺序,它的计算顺序跟forward函数顺序相关:

import torch.nn as nn

class net2(nn.Module):
    def __init__(self):
        super(net2, self).__init__()
        self.modulelist = nn.ModuleList([nn.Conv2d(3, 10, 3),
                                  nn.ReLU(),
                                  nn.Conv2d(32, 64, 3),
                                  nn.ReLU(),
                                  nn.Conv2d(10, 32, 3),
                                  nn.ReLU()
                                         ])
        self.conv1 = nn.Conv2d(32, 10, 3)
    def forward(self, x):
        x = self.modulelist[0](x)
        x = self.modulelist[1](x)
        x = self.modulelist[4](x)
        x = self.modulelist[5](x)
        x = self.modulelist[2](x)
        x = self.modulelist[3](x)
        x = self.conv1(x)
        return x

mymodel = net2()
print(mymodel)

输出:

net2(
  (modulelist): ModuleList(
    (0): Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
    (4): Conv2d(10, 32, kernel_size=(3, 3), stride=(1, 1))
    (5): ReLU()
  )
  (conv1): Conv2d(32, 10, kernel_size=(3, 3), stride=(1, 1))
)

但是为了方便使用,一般ModuleList内尽量按照顺序来。

区别三:

Sequential可以使用OrderedDict对每层进行命名,上面有讲过。

区别四:

有很多重复的模块时,使用ModuleList会方便很多:

import torch.nn as nn

class mod(nn.Module):
    def __init__(self, input_size):
        super(mod, self).__init__()
        self.input_size = input_size
        self.conv = nn.Conv2d(self.input_size, self.input_size, 3)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)

        return x


class net3(nn.Module):
    def __init__(self):
        super(net3, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.relu1 = nn.ReLU()

        layers = [mod(32) for i in range(3)]
        self.conv = nn.ModuleList(layers)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu1(x)
        for layer in self.conv:
            x = layer(x)
        x = self.conv2(x)
        x = self.relu2(x)
        return x

mymodel = net3()
print(mymodel)
for i, param in enumerate(mymodel.parameters()): # 查看模型参数
    print('第 %d 层参数大小:' % i, param.size())

输出:

net3(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
  (relu1): ReLU()
  (conv): ModuleList(
    (0): mod(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU()
    )
    (1): mod(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU()
    )
    (2): mod(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU()
    )
  )
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (relu2): ReLU()
)0 层参数: <class 'torch.Tensor'> torch.Size([32, 3, 3, 3])1 层参数: <class 'torch.Tensor'> torch.Size([32])2 层参数: <class 'torch.Tensor'> torch.Size([32, 32, 3, 3])3 层参数: <class 'torch.Tensor'> torch.Size([32])4 层参数: <class 'torch.Tensor'> torch.Size([32, 32, 3, 3])5 层参数: <class 'torch.Tensor'> torch.Size([32])6 层参数: <class 'torch.Tensor'> torch.Size([32, 32, 3, 3])7 层参数: <class 'torch.Tensor'> torch.Size([32])8 层参数: <class 'torch.Tensor'> torch.Size([64, 32, 3, 3])9 层参数: <class 'torch.Tensor'> torch.Size([64])

Process finished with exit code 0

可以看出,在ModuleList模块中有三个相同的模块,使用ModuleList比一个一个写方便很多。

  • 1
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值