../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:367: operator(): block: [1227,0,0], thread: [90,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:367: operator(): block: [1227,0,0], thread: [91,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:367: operator(): block: [1227,0,0], thread: [92,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:367: operator(): block: [1227,0,0], thread: [93,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/ScatterGatherKernel.cu:367: operator(): block: [1227,0,0], thread: [94,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
0%| | 0/250 [00:01<?, ?it/s]
Traceback (most recent call last):
File "/media/sunjiakuo/78de387e-1660-4914-b70a-224cd2c430e9/xtx/teacher/train_unet.py", line 82, in <module>
train_loss, train_dice = trainer.train_epoch()
File "/media/sunjiakuo/78de387e-1660-4914-b70a-224cd2c430e9/xtx/teacher/SegTrainer.py", line 173, in train_epoch
loss = self.loss_seg(seg_pred_batch, mask_batch)
File "/home/sunjiakuo/.local/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/media/sunjiakuo/78de387e-1660-4914-b70a-224cd2c430e9/xtx/teacher/MyLoss.py", line 149, in forward
tp, fp, fn, _ = get_tp_fp_fn_tn(net_output, gt, axes, loss_mask, False)
File "/media/sunjiakuo/78de387e-1660-4914-b70a-224cd2c430e9/xtx/teacher/MyLoss.py", line 98, in get_tp_fp_fn_tn
y_onehot.scatter_(1, gt, 1)
RuntimeError: CUDA error: device-side assert triggered
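This is an out-of-bounds error, caused by a tensor index that falls outside its valid range.
In my case, the error came from a dataset I had built myself in NIfTI (.nii) format: when saving the labels, I used RGB-style values, i.e. 255 to represent label 1 of the grayscale mask.
As a result, the gt values passed to scatter_ here were outside the valid index range. scatter_ uses each gt value as an index along the class dimension (dim 1), so every value must lie between 0 and num_classes - 1, and 255 does not.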
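To see why the out-of-range labels break scatter_, here is a minimal sketch that reproduces the failure on the CPU, where PyTorch reports an ordinary Python exception with the offending index instead of a device-side assert. The shapes (batch of 1, 2 classes, a 4×4 mask) are assumptions for illustration, and labels_in_range / scatter_fails are hypothetical helpers, not part of the original MyLoss.py.

```python
import torch


def labels_in_range(gt: torch.Tensor, num_classes: int) -> bool:
    """Hypothetical check: True if every label can index dim 1 of the one-hot tensor."""
    return bool(gt.min() >= 0) and bool(gt.max() < num_classes)


def scatter_fails(gt: torch.Tensor, num_classes: int) -> bool:
    """Hypothetical helper: try the same scatter_ call as get_tp_fp_fn_tn on the CPU."""
    y_onehot = torch.zeros(gt.shape[0], num_classes, *gt.shape[2:])
    try:
        # Each gt value is used as an index into dim 1, which has size num_classes.
        y_onehot.scatter_(1, gt, 1)
    except (RuntimeError, IndexError) as err:
        # The exact exception type/message may vary across PyTorch versions.
        print("scatter_ failed:", err)
        return True
    return False


# Binary mask whose foreground was saved as 255 instead of 1, shape (B, 1, H, W).
gt = torch.zeros(1, 1, 4, 4, dtype=torch.long)
gt[..., 1:3, 1:3] = 255

print(labels_in_range(gt, 2))
print(scatter_fails(gt, 2))
```

On the GPU the same call only produces the device-side assert shown in the log above; rerunning with the environment variable CUDA_LAUNCH_BLOCKING=1, or moving the tensors to the CPU, makes the failing call and index easier to pin down.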
The fix is to insert the following line before y_onehot.scatter_(1, gt, 1):
gt=gt.clamp(0,1)
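This clamps gt into the range [0, 1], so a value of 255 becomes 1 (and any negative value becomes 0), which resolves the error. Note that this only works for binary masks: with multi-class labels, clamping to [0, 1] would merge every class above 1 into class 1.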
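The effect of the clamp can be checked on a toy mask. This sketch assumes the same binary setup (classes 0 and 1, foreground stored as 255); make_onehot is a hypothetical wrapper around the scatter_ call, not the original get_tp_fp_fn_tn. It also shows thresholding with gt > 0 as an alternative that gives the same result for a {0, 255} mask.

```python
import torch


def make_onehot(gt: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical wrapper: one-hot encode a (B, 1, H, W) label map along dim 1."""
    y_onehot = torch.zeros(gt.shape[0], num_classes, *gt.shape[2:], device=gt.device)
    y_onehot.scatter_(1, gt.long(), 1)
    return y_onehot


# Toy (1, 1, 2, 2) mask with foreground stored as 255.
gt = torch.tensor([[[[0, 255], [255, 0]]]])

fixed_clamp = gt.clamp(0, 1)    # the fix from this post: 255 -> 1, 0 stays 0
fixed_thresh = (gt > 0).long()  # alternative: any nonzero value counts as foreground

onehot = make_onehot(fixed_clamp, num_classes=2)
print(onehot[0, 1])  # foreground channel
```

Both variants only make sense for binary masks. For multi-class data it is safer to fix the label values when the NIfTI files are written, so the loss never sees anything outside 0 to num_classes - 1.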
Since I couldn't find a solution for this bug anywhere online, I'm posting it here.
For more on SoftDiceLoss, see:
https://github.com/JunMa11/SegLoss/blob/master/test/nnUNetV2/loss_functions