错误信息如下
../aten/src/ATen/native/cuda/NLLLoss2d.cu:93: nll_loss2d_forward_kernel: block: [0,0,0], thread: [xxx,0,0] Assertion `t >= 0 && t < n_classes` failed.
运行背景: Linux/Windows 都有, 我是先用nnUnet预处理了医学数据,由于它有b好多不同的类别,也就是多数据集联合训练,有A,B,C三种类别(A 不等于B 不等于 C), 直接读取NifTi是没有问题的,读取预处理的npz格式数据,开始报错,本来以为是不是类别传错了,缩减到一个数据集也是报错. 最后注释掉CrossEntropy()就不报错了,找到了报错的代码块.
它指明了类别是不能为负数,也不能超过指定的类别,后者肯定不可能. 我使用 np.unique(),查了一下类别,发现里面是存在一个 -1的,发现nnUnet会把不属于0背景的其他背景会归为-1,你使用CrossEntropy就会报负类别错误.
解决方法:
nn.CrossEntropyLoss(weight=YOURWEIGHT,ignore_index=-1)
把负数类别忽略掉,用GPT查了一下如果忽略多个类别,应该这样:
ignored_classes = [class_idx1, class_idx2, class_idx3, ...] # 要忽略的类别的索引列表
weight = torch.ones(n).to(device) # 权重设置为1,可以根据需要进行调整
criterion = nn.CrossEntropyLoss(weight=weight, ignore_index=ignored_classes)
总结:
- 解决BUG的时候,按照它提示的,它说你类别多了,类别是负数了它不能deal with的时候,就把类别输出看看,有没有错误
- 前后因果分析:为什么第一次用这个数据可以,第二次用了第一次预处理后的数据就报错?那肯定就是在预处理阶段, 让你的数据变得不可被一些函数处理,那么就看看预处理后的数据发生了什么问题.
- 文档需要极度熟悉, 多看看torch的文档,多看看nnUnet开发者文档,这几天配WSL和nnUNet发现,网上很多的帖子,就是你抄我的,这个平台抄那个平台, 其实最终微软的WSL配置和nnUNet官网的readme,就是很清楚的.
- 网上的BUG一定是在某种情况下发生的, 如果没有写运行背景的,一定多察觉,他肯定和你不是一个环境下运行的,不是同一时间运行的,多考虑一下他和你运行的差别,一般官方会有讨论区,比如GitHub的issues.
- 这个总结写给我,也写给遇到BUG的每个朋友
如果有补充和修正, 欢迎指正