ModuleList是Module的子类,当在Module中使用它时,就能自动识别为子module。
输入:
from torch import nn
import torch as t
from torch.autograd import Variable as V
class MyModule(nn.Module):
def __init__(self):
super(MyModule,self).__init__()
self.list=[nn.Linear(3,4),nn.ReLU()]
self.module_list=nn.ModuleList([nn.Conv2d(3,3,3),nn.ReLU()])
def forward(self):
pass
# list 中的子module并不能被主module识别,而ModuleList中的子module能
#够被主module识别。这意味着如果用list保存子module,将无法调整其参数,
# 因为其未能加入到主module的参数中
model=MyModule()
print("---MyModule---")
print(model)
for name,param in model.named_parameters():
print(name,param.size())
model=MyModule().list
print(model)
# error
# for name,param in model.named_parameters():
# print(name,param.size())
输出:
---MyModule---
MyModule(
(module_list): ModuleList(
(0): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
(1): ReLU()
)
)
module_list.0.weight torch.Size([3, 3, 3, 3])
module_list.0.bias torch.Size([3])
[Linear(in_features=3, out_features=4, bias=True), ReLU()]
Process finished with exit code 0