Pytorch优化模型参数时,碰到一个报错, 如下:
从报错类型来看,不能优化非叶子节点的参数。
把model的参数打出来看看,如下图:
for name, parameter in model.named_parameters():
print('%-40s%-20s%s' %(name, parameter.requires_grad, parameter.is_leaf))
发现的确有两个参数为非叶子节点,同时其requires_grad=True,因此报错。
解决办法:先将这种参数detach掉【requires_grad=False】,然后再喂给优化器,如果确实需要优化这种参数,再调整其requires_grad=True,如下:
for name, parameter in model.named_parameters():
if not parameter.is_leaf:
parameter.detach_()
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
for name, parameter in model.named_parameters():
parameter.requires_grad = True
如此,问题解决,即可正常优化模型参数了。