知乎文章链接:https://zhuanlan.zhihu.com/p/64990232
class net6(nn.Module):
def __init__(self):
super(net6, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(3)])
def forward(self, x):
for layer in self.linears:
x = layer(x)
return x
net = net6()
print(net)
# net6(
# (linears): ModuleList(
# (0): Linear(in_features=10, out_features=10, bias=True)
# (1): Linear(in_features=10, out_features=10, bias=True)
# (2): Linear(in_features=10, out_features=10, bias=True)
# )
# )
这个是比较一般的方法,但如果不想这么麻烦,我们也可以用 Sequential 来实现,如 net7 所示!注意 * 这个操作符,它可以把一个 list 拆开成一个个独立的元素。但是,请注意这个 list 里面的模块必须是按照想要的顺序来进行排列的。在 场景一 中,我个人觉得使用 net7 这种方法比较方便和整洁。
class net7(nn.Module):
def __init__(self):
super(net7, self).__init__()
self.linear_list = [nn.Linear(10, 10) for i in range(3)]
self.linears = nn.Sequential(*self.linears_list)
def forward(self, x):
self.x = self.linears(x)
return x
net = net7()
print(net)
# net7(
# (linears): Sequential(
# (0): Linear(in_features=10, out_features=10, bias=True)
# (1): Linear(in_features=10, out_features=10, bias=True)
# (2): Linear(in_features=10, out_features=10, bias=True)
# )
# )
class net5(nn.Module):
def __init__(self):
super(net5, self).__init__()
self.block = nn.Sequential(nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU())
def forward(self, x):
x = self.block(x)
return x
net = net5()
print(net)
# net5(
# (block): Sequential(
# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
# (1): ReLU()
# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
# (3): ReLU()
# )
# )