network.apply(func)
——递归地对网络中的每个子模块调用 func(module)，最后再对网络本身调用一次（先子模块、后自身）
——一般用于参数初始化
@torch.no_grad()  # in-place fill_ on leaf parameters that require grad is only allowed without autograd tracking
def init_weights(m):
    """Initialize the parameters of a single module to fixed constants.

    Intended to be passed to ``nn.Module.apply``, which calls it once per
    module in the tree.

    - ``nn.Linear``: weight set to 1.0, bias (if present) set to 0.
    - ``nn.Conv2d``: weight set to 4.9; the bias is left untouched, so it
      keeps whatever values it already had.
    - Any other module is ignored.

    Args:
        m: the module currently being visited.
    """
    if isinstance(m, nn.Linear):
        m.weight.fill_(1.0)
        # Layers built with bias=False have ``bias is None``; skip them instead of crashing.
        if m.bias is not None:
            m.bias.fill_(0)
    elif isinstance(m, nn.Conv2d):
        m.weight.fill_(4.9)
# Toy container used only to demonstrate apply(); the layers are not meant to
# be chained in a forward pass.
net = nn.Sequential(nn.Linear(2, 2), nn.Conv2d(2, 2, 1))
# apply() visits net[0], net[1], and finally net itself; init_weights ignores
# the Sequential container.
net.apply(init_weights)
for param in net.parameters():
    print(param)

# Sample output. init_weights does not touch the Conv2d bias, so the last
# tensor keeps PyTorch's random default initialization and differs per run.
#
# Parameter containing:
# tensor([[1., 1.],
#         [1., 1.]], requires_grad=True)
# Parameter containing:
# tensor([0., 0.], requires_grad=True)
# Parameter containing:
# tensor([[[[4.9000]],
#
#          [[4.9000]]],
#
#
#         [[[4.9000]],
#
#          [[4.9000]]]], requires_grad=True)
# Parameter containing:
# tensor([0.3902, 0.0678], requires_grad=True)