自定义了一个Linear
类,
并用self.add_module('L1',nn.Linear(3,2))
添加了一层线性变换,
class Linear(nn.Module):
def __init__(self) :
super(Linear,self).__init__()
self.add_module('L1',nn.Linear(3,2))
self.add_module('L2',nn.Linear(2,1))
self.add_module('S3',nn.Sigmoid())
然后想要获取权重
LLL=Linear()
print(LLL[0].weight)
就报了这样的错误:TypeError: 'Linear' object is not subscriptable
然而用nn.Sequential() 定义模型时却不会有这样的问题
所以要怎么解决呢?
跳到nn.Module.add_module()的函数声明,注释里这样写到:
“这个新加入的模型可以用你给它的名字来获取到”
所以正确写法是:
print(LLL.L1.weight)