今天在处理ACDC的数据的时候,我想把标签的命名和之前MMWHS命名一致,但是后续在计算dice损失的时候一直如下的错误
RuntimeError: CUDA error: device-side assert triggered CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect。
一直报cuda断言的错误,找不到具体的错误。然后查看资料之后,把device设置为cpu,就可以查看具体的错误,发现是计算dice,将y变成独热向量的时候出错。
报错把y_onehot沿着y写的时候,出现越界的错误。查了一下scatter_()函数的实现是这样的。pytorch中的 scatter_()函数使用和详解_pytorch scatter-CSDN博客
y = y.scatter(dim,index,src)
#则结果为:
y[ index[i][j][k] ] [j][k] = src[i][j][k] # if dim == 0
y[i] [ index[i][j][k] ] [k] = src[i][j][k] # if dim == 1
y[i][j] [ index[i][j][k] ] = src[i][j][k] # if dim == 2
因为ACDC算上背景一共四个类,但是在和MMWHS一致时有一个类是5,越界了。所以需要对原来的数据集进行重写处理和划分,所以在后续处理多个数据集不同的类别需要有一定的规范。