- 报错:
AttributeError: cannot assign module before Module.__init__() call
- 错误代码:
class ModuleDict(nn.Module):
    """Select one conv layer and one activation by name (BROKEN version).

    This version deliberately omits the call to the parent constructor
    (``super().__init__()``), so instantiating it raises
    ``AttributeError: cannot assign module before Module.__init__() call``.

    Args:
        choice: key of the conv layer to use, 'conv1' or 'conv2'.
        act: key of the activation to use, 'relu' or 'prelu'.
    """
    def __init__(self, choice, act):
        # BUG (intentional, for illustration): nn.Module.__init__() is never called.
        # The error message refers to assigning a *module*, so these two plain
        # string attributes presumably do not trigger it.
        self.choice=choice
        self.act = act
        # First nn.Module-valued attribute: this assignment is where the
        # "cannot assign module before Module.__init__() call" error is raised.
        self.choices = nn.ModuleDict({
            'conv1':nn.Conv2d(10, 10, 3),
            'conv2':nn.Conv2d(10, 20, 3),
        })
        self.pool = nn.MaxPool2d(2)
        self.activations = nn.ModuleDict({
            'relu':nn.ReLU(),
            'prelu':nn.PReLU()
        })
    def forward(self, x):
        """Apply the selected conv layer, 2x2 max-pooling, then the selected activation."""
        x = self.choices[self.choice](x)
        x = self.pool(x)
        x = self.activations[self.act](x)
        return x
# With the broken ModuleDict defined just above (no super().__init__() call),
# the very first construction raises the AttributeError, so nothing below it runs.
net1 = ModuleDict('conv1', 'relu')
net2 = ModuleDict('conv2', 'relu')
# `summary` is not defined in this snippet (presumably torchsummary.summary — confirm).
# input_size presumably excludes the batch dimension — confirm against that library.
summary(net1,input_size=(10, 32, 32))
summary(net2,input_size=(10, 32, 32))
-
错误原因:
子类继承了父类 nn.Module 并重写了 __init__，但没有调用 super() 执行父类的构造函数。在 nn.Module.__init__() 执行之前给 self 赋值子模块（例如这里的 self.choices = nn.ModuleDict(...)），就会触发上述错误。
修改: 在子类 __init__ 的第一行添加代码 super(className, self).__init__()（Python 3 中也可简写为 super().__init__()）。
修改后代码:
class ModuleDict(nn.Module):
    """Small network that picks its conv layer and activation by name.

    The forward pass is: selected conv layer -> 2x2 max-pool -> selected activation.

    Args:
        choice: key into ``self.choices``; either 'conv1' (10 -> 10 channels)
            or 'conv2' (10 -> 20 channels), both with a 3x3 kernel.
        act: key into ``self.activations``; either 'relu' or 'prelu'.
    """

    def __init__(self, choice, act):
        # The parent constructor must run before any submodule is assigned;
        # otherwise nn.Module raises
        # "cannot assign module before Module.__init__() call".
        super(ModuleDict, self).__init__()
        self.choice = choice
        self.act = act

        conv_layers = {
            'conv1': nn.Conv2d(10, 10, 3),
            'conv2': nn.Conv2d(10, 20, 3),
        }
        self.choices = nn.ModuleDict(conv_layers)

        self.pool = nn.MaxPool2d(2)

        activation_layers = {
            'relu': nn.ReLU(),
            'prelu': nn.PReLU(),
        }
        self.activations = nn.ModuleDict(activation_layers)

    def forward(self, x):
        """Run the selected conv, the max-pool, then the selected activation on ``x``."""
        # Each layer is looked up only right before it is applied.
        out = self.choices[self.choice](x)
        out = self.pool(out)
        out = self.activations[self.act](out)
        return out
# Two variants of the network: both use ReLU and differ only in the selected conv layer.
net1 = ModuleDict('conv1', 'relu')
net2 = ModuleDict('conv2', 'relu')

# `summary` is not defined in this snippet (presumably torchsummary.summary — confirm).
# input_size presumably excludes the batch dimension — confirm against that library.
for net in (net1, net2):
    summary(net, input_size=(10, 32, 32))
运行结果: