Overview
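When a model's forward computation simply chains the computations of its layers one after another, the Sequential class offers a simpler way to define the model. That is exactly the purpose of Sequential: it accepts either an ordered dictionary (OrderedDict) of submodules or a series of submodules as arguments and adds these Module instances one by one. The model's forward computation then runs the instances in the order in which they were added.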
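As a small illustration of the two ways of passing submodules described above, the following sketch builds the same tiny network twice with nn.Sequential. The layer sizes and the names "hidden", "act" and "output" are illustrative choices, not taken from the original text.

```python
from collections import OrderedDict

import torch
from torch import nn

# Form 1: positional submodules; they are registered under the names "0", "1", "2".
net_positional = nn.Sequential(
    nn.Linear(4, 3),
    nn.ReLU(),
    nn.Linear(3, 2),
)

# Form 2: a single OrderedDict; its keys become the submodule names.
# The names used here are hypothetical and chosen only for readability.
net_named = nn.Sequential(OrderedDict([
    ("hidden", nn.Linear(4, 3)),
    ("act", nn.ReLU()),
    ("output", nn.Linear(3, 2)),
]))

X = torch.rand(5, 4)
out_positional = net_positional(X)
out_named = net_named(X)

# Collect the registered names in the order the submodules were added.
positional_names = [name for name, _ in net_positional.named_children()]
named_names = [name for name, _ in net_named.named_children()]
print(positional_names)
print(named_names)
print(out_positional.shape, out_named.shape)
```

With the OrderedDict form, a submodule can be reached either by position (net_named[0]) or by its name (net_named.hidden), which can make larger models easier to read.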
Below we implement a MySequential class that provides the same functionality as the Sequential class. Doing so may give a clearer picture of how Sequential works internally.
MySequential
from collections import OrderedDict

import torch
from torch import nn


class MySequential(nn.Module):
    def __init__(self, *args):
        super(MySequential, self).__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):  # a single OrderedDict was passed in
            for key, module in args[0].items():
                # add_module registers the module in self._modules (an ordered dict)
                self.add_module(key, module)
        else:  # a series of Module instances was passed in
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def forward(self, input):
        # self._modules preserves insertion order, so the members are
        # visited in the order in which they were added
        for module in self._modules.values():
            input = module(input)
        return input


net = MySequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
X = torch.rand(2, 784)
print(net)
print(net(X))
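The example above only exercises the branch that takes a series of modules. The sketch below exercises the OrderedDict branch as well and compares the result with nn.Sequential. It repeats a compact copy of MySequential so that it runs on its own; the layer names and the idea of sharing the same layer instances between both containers are assumptions made for this check, not part of the original text.

```python
from collections import OrderedDict

import torch
from torch import nn


class MySequential(nn.Module):
    """Compact copy of the class above so this snippet is self-contained."""

    def __init__(self, *args):
        super().__init__()
        if len(args) == 1 and isinstance(args[0], OrderedDict):
            for key, module in args[0].items():
                self.add_module(key, module)
        else:
            for idx, module in enumerate(args):
                self.add_module(str(idx), module)

    def forward(self, input):
        for module in self._modules.values():
            input = module(input)
        return input


# Shared layer instances, so both containers use identical weights.
hidden = nn.Linear(784, 256)
act = nn.ReLU()
output = nn.Linear(256, 10)

# OrderedDict branch: the keys (hypothetical names) become the submodule names.
my_net = MySequential(OrderedDict([
    ("hidden", hidden),
    ("act", act),
    ("output", output),
]))
ref_net = nn.Sequential(hidden, act, output)

X = torch.rand(2, 784)
module_names = list(my_net._modules.keys())
same = torch.allclose(my_net(X), ref_net(X))
print(module_names)
print(same)
```

Because both containers wrap the same three layer objects, matching outputs suggest that the only thing either container contributes is the order in which the layers are applied.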
Sequential
import sys

import torch
from torch import nn
from torch.nn import init

sys.path.append("..")  # lets Python find d2lzh_pytorch in the parent directory
import d2lzh_pytorch as d2l

num_inputs, num_hiddens, num_outputs = 784, 256, 10

net = nn.Sequential(
    d2l.FlattenLayer(),
    nn.Linear(num_inputs, num_hiddens),
    nn.ReLU(),
    nn.Linear(num_hiddens, num_outputs),
)

# Initialize every parameter (weights and biases alike) from N(0, 0.01^2)
for param in net.parameters():
    init.normal_(param, mean=0, std=0.01)

X = torch.rand(2, 784)
print('*' * 20)
print(net)
print(net(X))
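The listing above depends on the d2lzh_pytorch helper package. For readers who do not have it available, here is a minimal sketch of the same model that uses PyTorch's built-in nn.Flatten in place of d2l.FlattenLayer. This substitution is an assumption: both are expected to reshape each sample into a flat vector, but the snippet does not verify that against the d2l implementation. The image-shaped input is likewise chosen only for illustration.

```python
import torch
from torch import nn
from torch.nn import init

num_inputs, num_hiddens, num_outputs = 784, 256, 10

# nn.Flatten() is an assumed stand-in for d2l.FlattenLayer().
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(num_inputs, num_hiddens),
    nn.ReLU(),
    nn.Linear(num_hiddens, num_outputs),
)

# Same initialization as above: every parameter drawn from N(0, 0.01^2).
for param in net.parameters():
    init.normal_(param, mean=0, std=0.01)

# An image-shaped batch: 2 samples of 1 x 28 x 28 = 784 values each.
X = torch.rand(2, 1, 28, 28)
Y = net(X)
print(Y.shape)

# Rough check of the initialization: the first weight matrix has
# 784 * 256 entries, so its sample standard deviation should be near 0.01.
weight_std = net[1].weight.std().item()
print(weight_std)
```

The flattening layer is what allows the image-shaped batch to pass through the linear layers; with an input that is already of shape (2, 784), as in the listing above, it leaves the tensor unchanged.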