官方参考链接:torch.nn.MoudleList
以List形式保存submodule
ModuleList可以像常规的Python列表一样被索引,但它包含的模块是正确注册的,并且所有Module方法都可以看到。
import torch
import torch.nn as nn
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
# ModuleList can act as an iterable, or be indexed using ints
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
test_module = MyModule()
test_module.train()
input = torch.tensor(range(10), dtype=torch.float)
out = test_module(input)
debug:
i=0, l为第一个线性层
i=1, l为第二个线性层。