根据上次经验,估计还是切割数据的时候出了问题,去检查一下!
print(torch.isnan(out).any())
四个数据,需要分别检查,out_cnt是否和out一致,再检查一下out_cnt中是否有nan
首先检查发现,数据大小是对的。然后查一下是否有nan
gg,和上次错误可能不一样,训练中没有nan的成分啊。
重新回到训练的位置看哪里出了nan
这说明模型预测的数据中有nan
这就有些恶心了,估计是前向模型出了问题
我们可以发现是第13步出现了问题。我们将模型一步步回推到13步,观察前向模型的bug是什么。
可以发现此时整个x都是nan,所以也不奇怪了
可以发现此时out的初始数据已经有nan了。
我们发现初始数据是没有nan的。经过核实,发现再经过网络前也是没有nan也没有inf的。基本可以确定的确是网络输出了一堆nan。那么在探讨为什么网络会输出nan之前,我们先看看能否跳过这些nan
似乎是不行的,因为有太多的nan了。而且损失也卡住了。
我们开始怀疑是否是因为没有随机打乱---因为一开始数据可能并没有什么信息。 这可能导致一开始网络就训练炸了(表示怀疑,也不至于一直没有信息)。
这种情况仍然会出现。
考虑到实际情况中没有改动网络结构(网络就是一个全连接),仍然怀疑是因为数据切割出错导致的问题。
但是这个数据实在是太无辜了,最大值最小值平均值都很正常,为啥过了网络突然不正常了呢?
观察一下之前的结果吧
说实话,和之前一样的13步炸我是没想到的---因为我是有过随机的。这不合理呀,什么模型会就在第13步突然炸掉。我换个随机数种子试试
可以看到,虽然换了个随机数种子,但是输入输出并没有任何变化。这说明根本没有随机起来。回去验证随机环节。发现的确写的随机是假的。
不过那就意味着之前写的随机也一直是假的。两个都是假的那就是真的。但是之前是能够跑起来的。所以现在发生了什么导致又跑不起来了呢?
不管怎么样,这次反正调好了随机数。接着我们看结果。
说明随机数没了问题。但是网络原来的问题仍然存在。我们没办法怀疑网络,我仍然认为是因为之前的数据“离谱”导致了网络被带偏,因此,我再次回去检查数据切割 。
print(torch.isnan(out).any())没用到。
我仍然怀疑数据切割的问题。这次可能是另一个原因:切割的时候其它维度出了问题。但好像也不对---因为其它维度如果出问题应该会报错。
一个最愚蠢但可能有效的办法就是一层层看,到底什么时候变nan了。
print(out.max(),out.min())
我们是45层炸的,44step的时候还是一切如常
print(out.max(),out.min())
print(self.linear1.weight.max(),self.linear1.weight.min())
print(self.linear2.weight.max(),self.linear2.weight.min())
这还是很正常的数据
然而,第45层就出现了nan
和数据没有关系,应该就是上一次梯度回传的时候,把数据传炸了。
确定是再第44次梯度传炸后,可以具体看到传导的过程
此时模型还一切正常
更新完毕后模型中已经出现了nan
我重新写了一个代码,基本没改任何东西,仍然有nan!这会不会是最初始数据就有问题?
可惜最初版代码被我直接改了!还是太自信了。
从LSTM的版本开始改吧。跑了200轮,应该不会nan了。现在将它改成全连接
看样子现在也没有Nan。现在先跑着。我们看更改数据是不是带来这个nan的原因。
测试数据变成五维,再看,还是可以训练的,不出nan。排除matlab数据导入错的问题
接着尝试修改更多。不再使用验证集
nan了! 不再划分训练测试导致nan的唯一可能就是之前已经排除掉的那种!
我们重新回到今天一直nan的版本,试图恢复验证集,观察结果发现这次它不nan了!说明其它地方都可能没有问题,还是数量没有数对,导致它nan了。好,那为了更好的解决问题,我们回到初始版本,观察到底哪里没有数对。
如果这是正确的话,说明之前数据读取一直有问题,只是由于验证机的数据并不会反传才导致了之前并没有把模型训炸掉。似乎一切安好。但实际上,这些末尾的数据时时刻刻打算将模型搞炸。
为了验证这一点,我将原始数据打乱,发现仍然出现开始就nan的情况。这说明,的确存在着一些数据会让模型nan掉。由于之前没有打乱数据,验证集数据会被放到最后4个人中。这也就意味数据就是存在着问题,只是之前没有发现。
接着,我又将损失函数给改成了LMS而不是相关系数。
发现这个问题居然不再出现了?所以使用相关系数作为损失函数才是导致nan的原因???难道是因为相关系数的梯度算炸了?说实话,还真存在这种可能性。
这个结果虽然很差,但是至少说明一点,就是,的确出现了部分数据无法使用相关系数作为损失函数。换言之,这次可能真的是出现了梯度爆炸。
最后,才在师弟的帮助下找到了问题所在:损失函数写炸了
之前每个batch只返回一个值,这会让每个batch只有一个值有效。师弟猜测,除了返回的值,其它值都会被补0导致回传梯度爆炸。
损失函数只要当作一个数据去跑就行了!