模型容器(Containers)
nn.ModuleList
nn.ModuleList是nn.module的容器,用于包装一组网络层,以迭代的方式调用网络层。
主要方法:
- append(): 在ModuleList后面添加网络层
- extend(): 拼接两个ModuleList
- insert(): 指定在ModuleList中位置插入网络层
class ModuleList(nn.Module):
def __init__(self):
super(ModuleList, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)])
def forward(self, x):
for i, linear in enumerate(self.linears):
x = linear(x)
return x
代码调试
1、在95行设置断点,Debug进入ModuleList类中。
2、运行至87行,点击二十次Run to Cursor,然后点击步入,进入nn.ModuleList类中进行观察。
3、此时我们进入到ModuleList类中的__init__()中,如果module不为空,则会执行self += modules,这就是List的方法
此时我们的modules为List,所以程序会对ModuleList进行拼接。
4、返回到主程序我们就构建好了一个ModuleList
在forward里面我们就可以通过for循环获取ModuleList中的每一个网络层,这样就实现了20层的全连接网络层。