Pytorch中Sequential子类实现

简述

当模型的前向计算为简短串联各层的计算时, S e q u e n t i a l Sequential Sequential类可以通过更加简单的方式定义模型,这正是 S e q u e n t i a l Sequential Sequential类的目的:它可以接收一个子模块的有序字典 ( O r d e r e d D i c t ) (OrderedDict) (OrderedDict)或者一系列子模块作为参数来逐一添加 M o d u l e Module Module的实例,而模型的前向计算就是将这些实例按添加的顺序逐一计算.
下面我们实现一个与 S e q u e n t i a l Sequential Sequential类有相同功能的 M y S e q u e n t i a l MySequential MySequential类.这或许可以帮助更加清晰地理解 S e q u e n t i a l Sequential Sequential类地工作机制.
MySequential

from collections import OrderedDict

import torch
from torch import nn

class MySequential(nn.Module):
    from collections import OrderedDict
    def __init__(self, *args):
        super(MySequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):   #如果传入的是一个OrderedDict
            for key, module in args[0].items():
                self.add_module(key, module)     # add_module方法会将module添加进self._modules(一个OrderedDict)
        else:  # 传入的是一些Module
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def forward(self, input):
        # self.__modules返回一个OrderedDict, 保证会按照成员添加时的顺序遍历成员
        for module in self._modules.values():
            input = module(input)
        return input

net = MySequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)

X = torch.rand(2, 784)
print(net)
print(net(X))

在这里插入图片描述

Sequential

from torch.nn import init
import sys
sys.path.append("..")
import d2lzh_pytorch as d2l

import torch
from torch import nn


num_inputs, num_hiddens, num_outputs = 784, 256, 10
net = nn.Sequential(
    d2l.FlattenLayer(),
    nn.Linear(num_inputs, num_hiddens),
    nn.ReLU(),
    nn.Linear(num_hiddens, num_outputs)
)
for param in net.parameters():
    init.normal_(param, mean = 0, std = 0.01)
    
X = torch.rand(2, 784)
print('*'*20)
print(net)
print(net(X))

在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值