系列文章目录
PyTorch学习——关于tensor、Variable、nn.Parameter()、叶子节点、非叶子节点、detach()函数、查看网络层参数
pytorch优化器——add_param_group()介绍及示例、Yolov7 优化器代码示例
文章目录
- 系列文章目录
-
- PyTorch学习——关于tensor、Variable、nn.Parameter()、叶子节点、非叶子节点、detach()函数、查看网络层参数 [pytorch优化器——add_param_group()介绍及示例、Yolov7 优化器代码示例](http://t.csdn.cn/WhKmC)
- Bug关键词:pytorch 模型自定义参数不更新、网络梯度为None,参数不更新解、tensor参数有梯度,但不更新、nn.Parameter()参数不更新。
- 1、Bug介绍
- 2、Bug复现、解决bug的心路历程
- 3、复现总代码
- 4、关于解决tensor参数梯度,weight不更新的若干思路的总结:
- 总结
Bug关键词:pytorch 模型自定义参数不更新、网络梯度为None,参数不更新解、tensor参数有梯度,但不更新、nn.Parameter()参数不更新。
1、Bug介绍
记录一个bug。
bug描述:这是关于一个Pytorch模型自定义nn.Parameter()参数不更新的bug。
就是用nn.Parameter()定义一个参数变量,它可以根据loss调整大小。
2、Bug复现、解决bug的心路历程
程序非常简单,但我发现定义的权重变量(self.w)并没有变化。
为了复现bug,这里我用了resnet18网络跑数字手写体mini数据集,网络模型如下:
class resnet18(torch.nn.Module):
def __init__(self):
super(resnet18, self).__init__()
self.block1 = torch.nn.Sequential(
torch.nn.Conv2d(1, 10, 5),
torch.nn.MaxPool2d(2),
torch.nn.ReLU(True),
torch.nn.BatchNorm2d(10),
)
self.block2 = torch.nn.Sequential(
torch.nn.Conv2d(10, 20, 5),
torch.nn.MaxPool2d(2),
torch.nn.ReLU(True),
torch.nn.BatchNorm2d(20),
)
self.fc = torch.nn.Sequential(
torch.nn.Flatten(),
torch.nn.Linear(320, 10)
)
self.w = torch.nn.Parameter(torch.ones(2)) # 定义自学习参数
def forward(self, x):
x = self.block1(x)*self.w[0]
x = self.block2(x)*self.w[1]
x = self.fc(x)
return x
输出结果如下:
从上面的程序看逻辑上没有问题。程序正常运行,就是self.w没有变化。在网上查资料很多人说的是定义的变量不是叶子节点、或者没有求梯度、没有加入网络层等原因。我把这些不确定的原因都通过一些输出函数打印了出来。
结果:
是否是叶子节点:发现self.w在初始化的前两个输出是叶子节点,但后面就不是叶子节点了。
梯度:显卡单卡是可以求出梯度的,但梯度很小(例如0.0001),多卡运行梯度输出为None。
这里问题就出现了,我自己nn.Parameter()定义的变量正常来说跟网络层的权重一样是叶子节点,而且默认求梯度,为什么输出是非叶子节点且梯度为None。然后我开是蒙了。
关于叶子节点和梯度的介绍可以看这个:PyTorch学习——关于tensor、Variable、nn.Parameter()、叶子节点、非叶子节点、detach()函数、查看网络层参数。
之后我想是不是梯度太小加学习率太小。导致为0了。然后开始向学习率优化器方向入手。结果bug就是出现在了这里。想到这里就差不多知道bug出现的原因了,但并不是学习率太小。而是:
bug:自己定义的self.w参数并没有加入到优化器中迭代。 很多网络是直接对全局定义,那么自己定义的self.w参数自然加入了优化器。比如这样:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.5)
而我用的是yolo网络,优化器函数如下:
self.w参数没有加入了优化器的迭代行列自然不会根据损失调整大小。
#---------------------------------------------------------------------#
# 构造损失函数和优化函数
# 损失
criterion = torch.nn.CrossEntropyLoss()
pg0, pg1, pg2 ,pg3= [], [], [], [