(pytorch-深度学习系列)pytorch构造深度学习模型-学习笔记

pytorch构造深度学习模型

1. 通过继承module类的方式来构造模型

Module类是nn模块里提供的一个模型构造类,是所有神经网络模块的基类。
可以继承基类并重构 __init()__函数和 f o r w a r d ( ) forward() forward()函数的方式来构造模型。

以下是一个构造一个模型的例子:

import torch
from torch import nn

class MLP(nn.Module):
    # 声明带有模型参数的层,这里声明了两个全连接层
    def __init__(self, **kwargs):
        # 调用MLP父类Module的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数
        super(MLP, self).__init__(**kwargs)
        self.hidden = nn.Linear(784, 256) # 隐藏层
        self.act = nn.ReLU()
        self.output = nn.Linear(256, 10)  # 输出层

    # 定义模型的前向计算,即如何根据输入x计算返回所需要的模型输出
    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)

实例化刚刚构建的MLP类得到模型变量net,net(X)会调用MLP继承自Module类的__call__函数,这个函数将调用MLP类定义的forward函数来完成前向计算。

X = torch.rand(2, 784)
net = MLP()
print(net)
net(X)

2. 通过Sequential类定义简单的模型

如果定义的模型的前向计算就是简单的串联各层的计算时,可以通过Sequential类快速定义模型。它可以接收一个子模块的有序字典(OrderedDict)或者一系列子模块作为参数来逐一添加Module的实例,而模型的前向计算就是将这些实例按添加的顺序逐一计算。

定义一个与Sequential类有相同功能的类:ep_sequential来解读Sequential类的工作机制:

class ep_sequential(nn.Module):
    from collections import OrderedDict
    def __init__(self, *args):
        super(ep_sequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict): # 如果传入的是一个OrderedDict
            for key, module in args[0].items():
                self.add_module(key, module)  # add_module方法会将module添加进self._modules(一个OrderedDict)
        else:  # 传入的是一些Module
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)
                
    def forward(self, input):
        # self._modules返回一个 OrderedDict,保证会按照成员添加时的顺序遍历成员
        for module in self._modules.values():
            input = module(input)
        return input

这里通过forward函数就可以看出Sequential类实现的是简单的串联各层

net = ep_equential(
        nn.Linear(128, 256),
        nn.ReLU(),
        nn.Linear(256, 10), 
        )
print(net)
net(X)

3. 使用ModuleList类

ModuleList接收一个子模块的列表作为输入,可以像List那样进行append和extend操作:

net = nn.ModuleList([nn.Linear(128, 256), nn.ReLU()])
net.append(nn.Linear(256, 10)) # # 类似List的append操作
print(net[-1])  # 类似List的索引访问
print(net)
# net(torch.zeros(1, 128)) # 会报NotImplementedError

ModuleList仅仅是一个储存各种模块的列表,这些模块之间没有联系也没有顺序(所以不用保证相邻层的输入输出维度匹配),而且没有实现forward功能需要自己实现,所以上面执行net(torch.zeros(1, 128))会报NotImplementedError;而Sequential内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配,内部forward功能已经实现。

这是官网的例子:

class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

但是,ModuleList并不是一般的list,加入到ModuleList里的所有模块的参数会被自动添加到网络中。
下面的例子进行了对比:

class Module_ModuleList(nn.Module):
    def __init__(self):
        super(Module_ModuleList, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10)])

class Module_List(nn.Module):
    def __init__(self):
        super(Module_List, self).__init__()
        self.linears = [nn.Linear(10, 10)]

net1 = Module_ModuleList()
net2 = Module_List()

print("net1:")
for p in net1.parameters():
    print(p.size())

print("net2:")
for p in net2.parameters():
    print(p)

输出:

net1:
torch.Size([10, 10])
torch.Size([10])
net2:

从结果可以看出,使用ModuleList 初始化网络层,那么该层参数就会被自动加入调用ModuleList的网络的模型中。

4. 使用ModuleDict类

ModuleDict接收一个子模块的字典作为输入, 类似ModuleList,可以像字典那样进行添加访问操作:

net = nn.ModuleDict({
    'linear': nn.Linear(128, 256),
    'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # 添加
print(net['linear']) # 访问
print(net.output)
print(net)
# net(torch.zeros(1, 128)) # 会报NotImplementedError

和ModuleList一样,ModuleDict实例仅仅是存放了一些模块的字典,并没有定义forward函数。同样,ModuleDict也与Python的普通的Dict不同,ModuleDict里的所有模块的参数会被自动添加到整个网络中。

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值