版权归属:
- https://blog.csdn.net/halchan
- chanhal@outlook.com
更多关注:
- https://github.com/chanhal
- https://www.zhihu.com/people/chanhal
Introduction
在写pytorch代码时,发现并不是所有的写法都能将模块和参数注册到网络中,不信,看下面的代码:
class net(nn.Module):
def __init__(self):
super(net1, self).__init__()
self.linears = [nn.Linear(10,10) for i in range(2)]
def forward(self, x):
for m in self.linears:
x = m(x)
return x
mynet = net()
print(mynet)
返回:
net()
表示并没有注册将模块注册至网络中
而下面的两种写法可以将模块注册至网络中。
# 写法一
class net(nn.Module):
def __init__(self):
super(net1, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10,10) for i in range(2)])
def forward(self, x):
for m in self.linears:
x = m(x)
return x
mynet = net()
print(mynet)
返回:
net(
(linears): ModuleList(
(0): Linear(in_features=10, out_features=10, bias=True)
(1): Linear(in_features=10, out_features=10, bias=True)
)
)