nn.Sequential
里面的模块按照顺序进行排列,所以必须确保前一个模块的输出大小和下一个模块的输入大小是一致的。内部已经实现forward函数,因此不用写forward函数。但在继承nn,Module类
#例1:这是来自官方文档的例子
seq = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
print(seq)
# Sequential(
# (0): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
# (1): ReLU()
# (2): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
# (3): ReLU()
# )
#对上述seq进行输入
input = torch.randn(16, 1, 20, 20)
print(seq(input))
#torch.Size([16, 64, 12, 12])
#例2:或者继承nn.Module类的话,就要写出forward函数
class net1(nn.Module):
def __init__(self):
super(net1, self).__init__()
self.seq = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
def forward(self, x):
return self.seq(x)
#注意:按照下面这种利用for循环的方式也是可以得到同样结果的
#def forward(self, x):
# for s in self.seq:
# x = s(x)
# return x
#对net1进行输入
input = torch.randn(16, 1, 20, 20)
net1 = net1()
print(net1(input).shape)
#torch.Size([16, 64, 12, 12])
nn.ModuleList
储存不同module,并自动将每个module的parameters添加到网络之中的容器。没有实现内部forward函数
#例2:写出forward函数
class net2(nn.Module):
def __init__(self):
super(net2, self).__init__()
self.modlist = nn.ModuleList([
nn.Conv2d(1, 20, 5),
nn.ReLU(),
nn.Conv2d(20, 64, 5),
nn.ReLU()
])
#这里若按照这种写法则会报NotImplementedError错
#def forward(self, x):
# return self.modlist(x)
#注意:只能按照下面利用for循环的方式
def forward(self, x):
for m in self.modlist:
x = m(x)
return x
input = torch.randn(16, 1, 20, 20)
net2 = net2()
print(net2(input).shape)
#torch.Size([16, 64, 12, 12])