Pytorch中nn.Module和nn.Sequencial的简单学习


前言

   目前在学习 Pytorch 入门,很久之前进行了自定义模型的编码,但因为学业繁忙,时隔一周再来继续对 Pytorch 的学习,以及之前对 Python 的学习并不扎实,回过头再来看之前的代码需要再次理解,浪费时间,所以写下本博客对知识理解进行记录,也便后续回忆。

1、Python 类

下面介绍一些后续会用到的关于 Python 类的知识点:

  • __init__() 方法是一种特殊的方法,被称为类的构造函数或初始化方法,当创建了这个类的实例时就会调用该方法。
  • self 代表类的实例,self在定义类的方法时是必须有的,虽然在调用时不必传入相应的参数。
  • 类的方法和普通的函数只有一个特别的区别----它们必须有一个额外的第一个参数名称,按照惯例,它的名称是 self,当然换成其他名称也是可以的。

2、nn.Module 和 nn.Sequential

   该部分主要参考下面两条blog,个人感觉感jio很不错:

2.1 nn.Module

  Pytorch 中没有特别明显的 LayerModule 的区别,不管是自定义层、自定义块、自定义模型,都是通过继承 Module 类完成的,这一点很重要。其实 Sequential 类也是继承自 Module 类的。

  pytorch 里面一切自定义操作基本上都是继承自 nn.Module 类实现的。

2.1.1 torch.nn.Module类

先看源码:

class Module(object):
    def __init__(self):
    def forward(self, *input):
 
    def add_module(self, name, module):
    def cuda(self, device=None):
    def cpu(self):
    def __call__(self, *input, **kwargs):
    def parameters(self, recurse=True):
    def named_parameters(self, prefix='', recurse=True):
    def children(self):
    def named_children(self):
    def modules(self):  
    def named_modules(self, memo=None, prefix=''):
    def train(self, mode=True):
    def eval(self):
    def zero_grad(self):
    def __repr__(self):
    def __dir__(self):
# ...

  在自定义模型的时候,我们需要继承 nn.Module 类,并且需要重新实现构造函数 __init__() 方法和前向传播 forward() 方法。但有一些技巧需要注意:

  1. 一般把网络中具有可学习参数的层(如全连接层、卷积层等)放在构造函数 __init__() 中,当然也可以把不具有参数的层也放在里面;

  2. 一般把不具有可学习参数的层(如ReLUdropoutBatchNormanation层)可放在构造函数中,也可不放在构造函数中,如果不放在构造函数 __init__() 里面,则在 forward() 方法里面可以使用 nn.functional 来代替

  3. forward() 方法是必须要重写的,它是实现模型的功能,实现各个层之间的连接关系的核心。

  4. __init__() 方法中定义一系列层,此时层与层之间并没有连接关系,而在 forward() 方法中实现所有层的链接关系。

2.1.2 nn.Sequential 类

  nn.Sequential 类继承自nn.Module类,先来看定义:

class Sequential(Module): # 继承Module
    def __init__(self, *args):  # 重写了构造函数
    def _get_item_by_idx(self, iterator, idx):
    def __getitem__(self, idx):
    def __setitem__(self, idx, module):
    def __delitem__(self, idx):
    def __len__(self):
    def __dir__(self):
    def forward(self, input):  # 重写关键方法forward

  Sequential类的三种实现:

  1. 最简单的顺序模型
import torch.nn as nn
model = nn.Sequential(
                  nn.Conv2d(1,20,5),
                  nn.ReLU(),
                  nn.Conv2d(20,64,5),
                  nn.ReLU()
                )
 
print(model)
print(model[2]) # 通过索引获取第几个层
'''运行结果为:
Sequential(
  (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU()
  (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU()
)
Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
'''
  • 每个层没有名称,默认通过0、1、2、3来命名。
  1. 给每一个层添加名称(orderedDict)
import torch.nn as nn

from collections import OrderedDict
model = nn.Sequential(OrderedDict([
                  ('conv1', nn.Conv2d(1,20,5)),
                  ('relu1', nn.ReLU()),
                  ('conv2', nn.Conv2d(20,64,5)),
                  ('relu2', nn.ReLU())
                ]))
 
print(model)
print(model[2]) # 通过索引获取第几个层
'''运行结果为:
Sequential(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (relu1): ReLU()
  (conv2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
  (relu2): ReLU()
)
Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
'''
  • 从结果可以看出,此时每一层都有了自己的名字,但不能以名字来进行层的索引

model[2] →正确
model["conv2]→错误

  1. 第三种实现(add_module)
import torch.nn as nn
from collections import OrderedDict

model = nn.Sequential()
model.add_module("conv1",nn.Conv2d(1,20,5))
model.add_module('relu1', nn.ReLU())
model.add_module('conv2', nn.Conv2d(20,64,5))
model.add_module('relu2', nn.ReLU())
 
print(model)
print(model[2]) # 通过索引获取第几个层
  • Sequential 类并没有定义 add_module() 方法,实际上这个方法是定义在它的父类 Module 里面的,Sequential 继承了该方法。它的定义如下:
def add_module(self, name, module)

3.自己的示例

  再看当初自己写的代码,便不难理解了:

import torch.nn as nn

class LinearNet(nn.Module):
    def __init__(self, n_feature):
        
        # 这是对继承自父类的属性进行初始化。而且是用父类的初始化方法来初始化继承的属性。
        # 也就是说,子类继承了父类的所有属性和方法,父类属性自然会用父类方法来进行初始化。
        # 当然,如果初始化的逻辑与父类的不同,不使用父类的方法,自己重新初始化也是可以的。
        super(LinearNet, self).__init__()
        self.linear = nn.Linear(n_feature, 1)
    
    # 前向传播
    def forward(self, x):
        y = self.linear(x)
        return y
    
net = LinearNet(2)
print(net)           # 使用print可以打印出网络的结构
  • 0
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值